tpu-inference 0.11.1.dev202511150811__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +0 -0
- tests/core/__init__.py +0 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +53 -0
- tests/core/test_dp_scheduler.py +899 -0
- tests/core/test_init.py +49 -0
- tests/kernels/__init__.py +0 -0
- tests/kernels/fused_moe_v1_test.py +105 -0
- tests/kernels/mla_v1_test.py +396 -0
- tests/kernels/quantized_matmul_kernel_test.py +191 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
- tests/lora/__init__.py +0 -0
- tests/lora/conftest.py +32 -0
- tests/lora/test_bgmv.py +43 -0
- tests/lora/test_layers.py +654 -0
- tests/lora/test_lora.py +133 -0
- tests/lora/utils.py +96 -0
- tests/test_base.py +201 -0
- tests/test_envs.py +182 -0
- tests/test_quantization.py +836 -0
- tests/test_tpu_info.py +120 -0
- tests/test_utils.py +236 -0
- tpu_inference/__init__.py +34 -0
- tpu_inference/core/__init__.py +0 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +51 -0
- tpu_inference/core/sched/__init__.py +0 -0
- tpu_inference/core/sched/dp_scheduler.py +523 -0
- tpu_inference/distributed/__init__.py +0 -0
- tpu_inference/distributed/jax_parallel_state.py +67 -0
- tpu_inference/distributed/tpu_connector.py +728 -0
- tpu_inference/distributed/utils.py +59 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +107 -0
- tpu_inference/executors/__init__.py +0 -0
- tpu_inference/executors/ray_distributed_executor.py +362 -0
- tpu_inference/experimental/__init__.py +0 -0
- tpu_inference/experimental/llama3_jax_stashed.py +258 -0
- tpu_inference/kernels/__init__.py +0 -0
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +0 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
- tpu_inference/kernels/mla/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/__init__.py +0 -0
- tpu_inference/kernels/mla/v1/kernel.py +1349 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_interface.py +390 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/common/sharding.py +582 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +255 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +280 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +96 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
- tpu_inference/layers/jax/transformer_block.py +107 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +507 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +39 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
- tpu_inference/layers/vllm/sharding.py +230 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +0 -0
- tpu_inference/lora/torch_lora_ops.py +103 -0
- tpu_inference/lora/torch_punica_tpu.py +311 -0
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/__init__.py +0 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +444 -0
- tpu_inference/models/jax/__init__.py +0 -0
- tpu_inference/models/jax/deepseek_v3.py +868 -0
- tpu_inference/models/jax/gpt_oss.py +492 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
- tpu_inference/models/jax/llama3.py +375 -0
- tpu_inference/models/jax/llama4.py +629 -0
- tpu_inference/models/jax/llama_eagle3.py +333 -0
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +375 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
- tpu_inference/models/jax/qwen3.py +302 -0
- tpu_inference/models/jax/utils/__init__.py +0 -0
- tpu_inference/models/jax/utils/file_utils.py +96 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
- tpu_inference/models/jax/utils/weight_utils.py +529 -0
- tpu_inference/models/vllm/__init__.py +0 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
- tpu_inference/platforms/__init__.py +2 -0
- tpu_inference/platforms/tpu_platform.py +269 -0
- tpu_inference/runner/__init__.py +0 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +780 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +132 -0
- tpu_inference/runner/kv_cache_manager.py +479 -0
- tpu_inference/runner/lora_utils.py +92 -0
- tpu_inference/runner/multimodal_manager.py +217 -0
- tpu_inference/runner/persistent_batch_manager.py +244 -0
- tpu_inference/runner/speculative_decoding_manager.py +248 -0
- tpu_inference/runner/structured_decoding_manager.py +88 -0
- tpu_inference/runner/tpu_runner.py +1620 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +0 -0
- tpu_inference/spec_decode/jax/__init__.py +0 -0
- tpu_inference/spec_decode/jax/eagle3.py +367 -0
- tpu_inference/tpu_info.py +77 -0
- tpu_inference/utils.py +317 -0
- tpu_inference/worker/__init__.py +0 -0
- tpu_inference/worker/tpu_worker.py +321 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,185 @@
|
|
|
1
|
+
from dataclasses import InitVar, dataclass
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
from flax import nnx
|
|
6
|
+
from flax.typing import Sharding
|
|
7
|
+
from jaxtyping import Float
|
|
8
|
+
|
|
9
|
+
from tpu_inference.layers.jax.base import create_param
|
|
10
|
+
from tpu_inference.layers.jax.layers import FlaxUtils
|
|
11
|
+
from tpu_inference.layers.jax.moe.moe import Router
|
|
12
|
+
|
|
13
|
+
# NOTE(review): this module-level FlaxUtils instance is not referenced anywhere
# else in this module; candidate for removal.
modeling_flax_utils = FlaxUtils()
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass(kw_only=True)
class GptOssRouter(Router):
    """Top-k expert router for the GPT-OSS Mixture-of-Experts block.

    Extends ``Router`` with a learned per-expert bias that is added to the
    router logits before the top-k selection. The kept logits are then
    renormalized with a softmax over the ``num_experts_per_tok`` selected
    experts (the parent's ``router_act`` is not consulted).

    Attributes:
        e_sharding: Sharding spec of the ``(num_experts,)`` bias vector.
    """
    e_sharding: Sharding = ()

    def __post_init__(self, rngs: nnx.Rngs):
        """Create the parent's router kernel, then the per-expert bias."""
        super().__post_init__(rngs)

        self.bias_E = create_param(
            rngs,
            shape=(self.num_experts, ),
            dtype=self.dtype,
            sharding=self.e_sharding,
            random_init=self.random_init,
        )

    def __call__(self, x_TD: Float):
        """Choose experts for each token, including the learned bias.

        Args:
            x_TD: Token activations of shape ``(T, D)``. They are cast to
                ``self.dtype`` and constrained to ``self.activation_ffw_td``.

        Returns:
            A pair ``(weights, experts)`` of shape ``(T, num_experts_per_tok)``
            each: softmax-normalized scores of the selected experts and their
            indices.
        """
        tokens_TD = nnx.with_sharding_constraint(jnp.asarray(x_TD, self.dtype),
                                                 self.activation_ffw_td)

        logits_TE = jnp.einsum('TD,DE -> TE', tokens_TD, self.kernel_DE.value)
        logits_TE = logits_TE + self.bias_E.value

        top_logits_TX, experts_TX = jax.lax.top_k(logits_TE,
                                                  self.num_experts_per_tok)
        weights_TX = jax.nn.softmax(top_logits_TX.astype(self.dtype), axis=-1)
        return weights_TX, experts_TX
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _swiglu(x: Float, alpha: Float, limit: Float) -> Float:
    """Clamped SwiGLU activation used by the GPT-OSS experts.

    The last axis of ``x`` holds gate and linear channels interleaved
    (``[g0, l0, g1, l1, ...]``). The gate is clamped from above at ``limit``,
    the linear channel to ``[-limit, limit]``, and the result is
    ``gate * sigmoid(alpha * gate) * (linear + 1)``.

    Args:
        x: Fused projection whose last axis has even length ``2F``.
        alpha: Scale applied to the gate inside the sigmoid.
        limit: Clamp bound for both channels.

    Returns:
        Array with the same leading axes as ``x`` and last axis ``F``.
    """
    x_glu, x_linear = x[..., ::2], x[..., 1::2]

    # Pass the bounds positionally / via jnp.minimum: the ``a_min``/``a_max``
    # keywords of ``jnp.clip`` are deprecated in recent JAX releases, while
    # positional bounds work on both old and new signatures.
    x_glu = jnp.minimum(x_glu, limit)
    x_linear = jnp.clip(x_linear, -limit, limit)

    gated_activation = x_glu * jax.nn.sigmoid(alpha * x_glu)

    return gated_activation * (x_linear + 1)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@dataclass(kw_only=True)
class CombineExperts(nnx.Module):
    """Reduce per-expert outputs to one vector per token.

    For every token, the outputs of the router-selected experts are gathered
    from the dense per-expert result and summed with the router weights.
    """
    dtype: jnp.dtype

    def __call__(self, down_proj_TED: Float, weights_TX: Float,
                 indices_TX: jax.Array) -> Float:
        """Weighted sum over the selected experts.

        Args:
            down_proj_TED: Expert outputs, shape (tokens, experts, hidden_dim).
            weights_TX: Router weights, shape (tokens, experts_per_token).
            indices_TX: Selected expert indices, shape
                (tokens, experts_per_token).

        Returns:
            Combined output of shape (tokens, hidden_dim), cast to
            ``self.dtype``.
        """
        with jax.named_scope("combine_experts"):
            # A trailing singleton axis lets the (T, X) indices broadcast over D.
            gather_idx_TX1 = indices_TX[..., None]
            selected_TXD = jnp.take_along_axis(down_proj_TED,
                                               gather_idx_TX1,
                                               axis=1)
            mixed_TD = jnp.einsum('TXD,TX -> TD', selected_TXD, weights_TX)
        return mixed_TD.astype(self.dtype)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
@dataclass(kw_only=True)
class GptOssMoE(nnx.Module):
    """JAX implementation of the GPT-OSS Mixture-of-Experts MLP block.

    The block is computed densely. Every local expert processes every token
    through:
      * a fused gate/linear up-projection with bias,
      * the clamped SwiGLU from ``_swiglu``,
      * a down-projection with bias.
    ``CombineExperts`` then keeps only the router-selected experts per token
    and mixes them with the router weights.

    Attributes:
        router: Produces ``(weights_TX, indices_TX)`` for the tokens.
        swiglu_limit: Clamp bound passed to ``_swiglu``.
        swiglu_alpha: Sigmoid scale passed to ``_swiglu``.
        activation_ffw_td: Sharding constraint for the ``(T, D)`` input.
        edf_sharding: Sharding of the ``(E, D, 2F)`` up-projection weight.
        efd_sharding: Sharding of the ``(E, F, D)`` down-projection weight.
        ed_sharding: Sharding used for both biases, ``(E, 2F)`` and ``(E, D)``.
        random_init: Forwarded to ``create_param``.
    """
    dtype: jnp.dtype
    hidden_size: int
    intermediate_size_moe: int
    num_local_experts: int
    router: GptOssRouter
    rngs: InitVar[nnx.Rngs]

    swiglu_limit: float = 7.0
    swiglu_alpha: float = 1.702

    # Sharding specifications
    activation_ffw_td: Sharding
    edf_sharding: Sharding
    efd_sharding: Sharding
    ed_sharding: Sharding

    random_init: bool = False

    def __post_init__(self, rngs: nnx.Rngs):
        """Create the combiner, then the expert weights and biases.

        The parameters are created in this order from the same ``rngs``:
        mlp1 weight, mlp1 bias, mlp2 weight, mlp2 bias.
        """
        D = self.hidden_size
        F = self.intermediate_size_moe
        E = self.num_local_experts

        self.combine_experts = CombineExperts(dtype=self.dtype)

        def _param(shape, sharding):
            # All expert parameters share dtype and init mode.
            return create_param(rngs,
                                shape=shape,
                                dtype=self.dtype,
                                sharding=sharding,
                                random_init=self.random_init)

        # MLP #1: fused gate + linear projection; the last axis holds both
        # channel groups interleaved, hence 2 * F.
        self.mlp1_weight_EDF2 = _param((E, D, F * 2), self.edf_sharding)
        self.mlp1_bias_EF2 = _param((E, F * 2), self.ed_sharding)

        # MLP #2: down-projection back to the model dimension.
        self.mlp2_weight_EFD = _param((E, F, D), self.efd_sharding)
        self.mlp2_bias_ED = _param((E, D), self.ed_sharding)

    def __call__(self, x_TD: Float) -> Float:
        """Apply the MoE block to token activations of shape ``(T, D)``.

        Returns:
            Output of shape ``(T, D)`` in ``self.dtype``.
        """
        x_TD = nnx.with_sharding_constraint(jnp.asarray(x_TD, self.dtype),
                                            self.activation_ffw_td)

        weights_TX, indices_TX = self.router(x_TD)

        with jax.named_scope("MLP #1"):
            up_proj_TEF2 = (jnp.einsum('TD,EDF -> TEF', x_TD,
                                       self.mlp1_weight_EDF2.value) +
                            self.mlp1_bias_EF2.value)
        fuse_TEF = _swiglu(up_proj_TEF2,
                           alpha=self.swiglu_alpha,
                           limit=self.swiglu_limit)

        with jax.named_scope("MLP #2"):
            down_proj_TED = (jnp.einsum('TEF,EFD -> TED', fuse_TEF,
                                        self.mlp2_weight_EFD.value) +
                             self.mlp2_bias_ED.value)

        # Keep only the selected experts per token and mix them.
        return self.combine_experts(down_proj_TED, weights_TX, indices_TX)
|
|
@@ -0,0 +1,209 @@
|
|
|
1
|
+
from dataclasses import InitVar, dataclass
|
|
2
|
+
|
|
3
|
+
import jax
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
from flax import nnx
|
|
6
|
+
from flax.typing import Sharding
|
|
7
|
+
from jaxtyping import Float
|
|
8
|
+
|
|
9
|
+
from tpu_inference.layers.jax.base import create_param
|
|
10
|
+
from tpu_inference.layers.jax.layers import FlaxUtils
|
|
11
|
+
|
|
12
|
+
# Shared FlaxUtils instance; Router and MoE in this module resolve activation
# functions by name through its ``ACT2FN`` mapping.
modeling_flax_utils = FlaxUtils()
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass(kw_only=True)
class Router(nnx.Module):
    """Top-k router for Mixture-of-Experts (MoE) layers.

    Scores every expert for each token with a ``(hidden_size, num_experts)``
    kernel, keeps the ``num_experts_per_tok`` highest scores, and normalizes
    the kept scores with the activation named by ``router_act``.

    Attributes:
        dtype: Dtype of the kernel and of the computation.
        hidden_size: Model dimension ``D`` of the input tokens.
        num_experts: Number of experts ``E`` scored per token.
        num_experts_per_tok: Number of experts kept per token.
        router_act: Key into ``ACT2FN``. ``"sigmoid"`` is applied element-wise;
            any other activation is called with ``axis=-1``.
        activation_ffw_td: Sharding constraint applied to the input tokens.
        ed_sharding: Sharding of the router kernel.
        random_init: Forwarded to ``create_param``.
    """
    dtype: jnp.dtype
    hidden_size: int
    num_experts: int
    num_experts_per_tok: int
    router_act: str
    rngs: InitVar[nnx.Rngs]
    activation_ffw_td: Sharding
    ed_sharding: Sharding
    random_init: bool = False

    def __call__(self, x_TD: Float):
        """Route tokens to experts.

        Args:
            x_TD: Input array of shape (sequence_length, d_model).

        Returns:
            A tuple ``(normalized_weights_TX, selected_experts_TX)``, both of
            shape (sequence_length, num_experts_per_tok): the normalized scores
            of the selected experts and their indices.
        """
        x_TD = nnx.with_sharding_constraint(jnp.asarray(x_TD, self.dtype),
                                            self.activation_ffw_td)
        act_fn = modeling_flax_utils.ACT2FN[self.router_act]

        logits_TE = jnp.einsum('TD,DE -> TE', x_TD, self.kernel_DE.value)
        top_scores_TX, selected_experts_TX = jax.lax.top_k(
            logits_TE, self.num_experts_per_tok)
        top_scores_TX = top_scores_TX.astype(self.dtype)

        if self.router_act == "sigmoid":
            # Element-wise activation: sigmoid takes no ``axis`` argument.
            return act_fn(top_scores_TX), selected_experts_TX
        return act_fn(top_scores_TX, axis=-1), selected_experts_TX

    def __post_init__(self, rngs: nnx.Rngs):
        """Create the ``(hidden_size, num_experts)`` router kernel."""
        self.kernel_DE = create_param(
            rngs,
            shape=(self.hidden_size, self.num_experts),
            dtype=self.dtype,
            sharding=self.ed_sharding,
            random_init=self.random_init,
        )
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@dataclass(kw_only=True)
class MoE(nnx.Module):
    """Mixture-of-Experts (MoE) routed MLP layer, computed densely.

    Every token is run through every local expert's gated MLP. The router's
    top-k choices only determine the weighting; experts not selected for a
    token receive weight 0.

    Attributes:
        router: Module returning ``(weights_TX, indices_TX)`` for the tokens.
        apply_expert_weight_before_computation: If True, each expert's copy of
            the input is scaled by its routing weight before the MLP, and the
            outputs are summed. Otherwise the MLP runs on the unweighted input,
            and the outputs are combined with the routing weights.
        hidden_act: Key into ``ACT2FN`` for the gating activation.
        activation_ffw_td: Sharding constraint for ``(T, D)`` activations.
        activation_ffw_ted: Sharding constraint for per-expert ``(T, E, D)``
            activations.
        edf_sharding: Sharding of the gating and up-projection kernels.
        efd_sharding: Sharding of the down-projection kernel.
        random_init: Forwarded to ``create_param``.
    """
    dtype: jnp.dtype
    num_local_experts: int
    apply_expert_weight_before_computation: bool
    hidden_size: int
    intermediate_size_moe: int
    hidden_act: str
    rngs: InitVar[nnx.Rngs]
    router: nnx.Module
    activation_ffw_td: Sharding
    activation_ffw_ted: Sharding
    edf_sharding: Sharding
    efd_sharding: Sharding
    random_init: bool = False

    def __call__(self, x_TD: Float):
        """Run the MoE layer.

        Args:
            x_TD: Input array of shape (sequence_length, d_model).

        Returns:
            Output array of shape (sequence_length, d_model).
        """
        x_TD = nnx.with_sharding_constraint(jnp.asarray(x_TD, self.dtype),
                                            self.activation_ffw_td)
        weights_TX, indices_TX = self.router(x_TD)

        # Scatter the top-k weights into a dense (T, E) matrix.
        expert_mask_TXE = jax.nn.one_hot(indices_TX,
                                         num_classes=self.num_local_experts,
                                         dtype=self.dtype)
        full_weights_TE = jnp.sum(expert_mask_TXE * weights_TX[..., None],
                                  axis=1)

        if not self.apply_expert_weight_before_computation:
            return self._moe_fwd(x_TD, full_weights_TE)

        # Some models weight the inputs with the routing scores instead of
        # weighting the expert outputs.
        with jax.named_scope("pre_computing_weight"):
            return self._moe_fwd_preapply_router_weights(
                x_TD, full_weights_TE)

    def __post_init__(self, rngs: nnx.Rngs):
        """Create the expert kernels: gating, up-projection, down-projection.

        The router is a separate module passed in by the caller; no router
        weights are created here.
        """
        E = self.num_local_experts
        D = self.hidden_size
        F = self.intermediate_size_moe

        def _param(shape, sharding):
            # Every expert kernel shares dtype and init mode.
            return create_param(rngs,
                                shape=shape,
                                dtype=self.dtype,
                                sharding=sharding,
                                random_init=self.random_init)

        self.kernel_gating_EDF = _param((E, D, F), self.edf_sharding)
        self.kernel_up_proj_EDF = _param((E, D, F), self.edf_sharding)
        self.kernel_down_proj_EFD = _param((E, F, D), self.efd_sharding)

    def _run_experts(self, x, x_spec: str):
        """Apply every expert's gated MLP to ``x``.

        Args:
            x: Token activations with einsum layout ``x_spec``.
            x_spec: ``'TD'`` when all experts share the same input, or
                ``'TED'`` when each expert gets its own copy.

        Returns:
            Per-expert outputs of shape ``(T, E, D)``, before any reduction.
        """
        with jax.named_scope("gating"):
            gating_TEF = jnp.einsum(f'{x_spec},EDF -> TEF', x,
                                    self.kernel_gating_EDF.value)
            activated_gating_TEF = modeling_flax_utils.ACT2FN[
                self.hidden_act](gating_TEF)
        with jax.named_scope("up_projection"):
            up_proj_TEF = jnp.einsum(f'{x_spec},EDF -> TEF', x,
                                     self.kernel_up_proj_EDF.value)

        fuse_TEF = activated_gating_TEF * up_proj_TEF

        with jax.named_scope("down_projection"):
            down_proj_TED = jnp.einsum('TEF,EFD -> TED', fuse_TEF,
                                       self.kernel_down_proj_EFD.value)
        return down_proj_TED

    def _moe_fwd_preapply_router_weights(self, x_TD: jax.Array, weights_TE):
        """Run the experts on inputs pre-scaled by the router weights.

        Args:
            x_TD: Input array for the experts, shape (sequence_length,
                hidden_size).
            weights_TE: Router weights, shape (sequence_length, num_experts).

        Returns:
            Output array of shape (sequence_length, d_model).
        """
        # Each expert needs its own copy of the tokens, because every copy is
        # scaled by a different routing weight.
        expert_count = weights_TE.shape[-1]
        x_TED = jnp.repeat(x_TD[:, None, :], expert_count, 1)
        weights_TED = weights_TE[..., None]
        x_TED = jnp.asarray(x_TED, self.dtype)

        with jax.named_scope("activation_expert_weighting"):
            x_TED = x_TED * weights_TED

        x_TED = nnx.with_sharding_constraint(x_TED, self.activation_ffw_ted)
        down_proj_TED = self._run_experts(x_TED, 'TED')

        with jax.named_scope("sum"):
            output_TD = down_proj_TED.sum(axis=1)
        return output_TD.astype(self.dtype)

    def _moe_fwd(self, x_TD: Float, weights):
        """Run the experts on shared inputs, then mix the outputs by weight.

        This is the basic dense path: no token dropping, no megablocks.

        Args:
            x_TD: Input array for the experts, shape (sequence_length, d_model).
            weights: Weights for combining expert outputs, shape
                (sequence_length, num_experts).

        Returns:
            Output array of shape (sequence_length, d_model).
        """
        x_TD = nnx.with_sharding_constraint(jnp.asarray(x_TD, self.dtype),
                                            self.activation_ffw_td)
        down_proj_TED = self._run_experts(x_TD, 'TD')

        with jax.named_scope("sum"):
            output_TD = jnp.einsum('TED,TE -> TD', down_proj_TED, weights)
        return output_TD.astype(self.dtype)
|
|
@@ -0,0 +1,280 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import Optional, Tuple
|
|
4
|
+
|
|
5
|
+
import jax
|
|
6
|
+
from flax import nnx
|
|
7
|
+
from jax import numpy as jnp
|
|
8
|
+
from jax.experimental.layout import Layout, with_layout_constraint
|
|
9
|
+
from jax.sharding import NamedSharding, PartitionSpec
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass(kw_only=True)
class RotaryEmbedding(nnx.Module):
    """Original rotary positional embedding (half-split rotation).

    The head dimension is split into a first and a second half, and each pair
    ``(first[i], second[i])`` is rotated by a position-dependent angle.
    Sin/cos tables for positions ``[0, original_max_position_embeddings)``
    are computed once by ``initialize_cache``.
    """
    rotary_dim: int
    rope_theta: float
    original_max_position_embeddings: int
    dtype: jnp.dtype
    # Rows are positions; columns are [cos | sin], each rotary_dim // 2 wide.
    sin_cos_cache: Optional[jax.Array] = field(init=False, default=None)

    def initialize_cache(self):
        """Build the sin/cos table, unless it already exists."""
        if self.sin_cos_cache is not None:
            return
        self.sin_cos_cache = self._compute_sin_cos()

    def _compute_inv_freq(self):
        """Return the per-pair inverse frequencies ``theta ** (-2i / dim)``."""
        exponents_H = jnp.arange(
            0, self.rotary_dim, 2, dtype=jnp.float32) / self.rotary_dim
        return 1.0 / (self.rope_theta**exponents_H)

    def _compute_sin_cos(self):
        """Return the ``(positions, rotary_dim)`` table of ``[cos | sin]``."""
        inv_freq_H = self._compute_inv_freq()
        positions_T = jnp.arange(self.original_max_position_embeddings,
                                 dtype=jnp.float32)
        # Outer product position x frequency, computed at full precision.
        angles_TH = jnp.einsum("...T,k->...Tk",
                               positions_T,
                               inv_freq_H,
                               precision=jax.lax.Precision.HIGHEST)
        return jnp.concatenate((jnp.cos(angles_TH), jnp.sin(angles_TH)),
                               axis=-1)

    def apply_rope(self, positions: jax.Array, x_TNH: jax.Array):
        """Rotate ``x_TNH`` (tokens, heads, head_dim) by the given positions."""
        assert x_TNH.ndim == 3
        assert self.sin_cos_cache is not None, "RoPE cache not initialized."

        # Each of cos/sin: (T, H/2).
        cos_TH, sin_TH = jnp.split(self.sin_cos_cache[positions], 2, axis=-1)
        assert sin_TH.ndim == 2 and cos_TH.ndim == 2

        # Add a head axis so the tables broadcast over N: (T, 1, H/2).
        cos_T1H = cos_TH[:, None, :]
        sin_T1H = sin_TH[:, None, :]

        # Halves of the head dimension: (T, N, H/2) each.
        x1_TNH, x2_TNH = jnp.split(x_TNH, 2, axis=-1)
        rotated_TNH = jnp.concatenate(
            [
                x1_TNH * cos_T1H - x2_TNH * sin_T1H,
                x2_TNH * cos_T1H + x1_TNH * sin_T1H,
            ],
            axis=-1,
        )
        return rotated_TNH.astype(self.dtype)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@dataclass(kw_only=True)
|
|
67
|
+
class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
|
68
|
+
"""
|
|
69
|
+
Rotary Embedding for deepseek, with scaling and YaRN method.
|
|
70
|
+
"""
|
|
71
|
+
scaling_factor: float
|
|
72
|
+
beta_fast: int = 32
|
|
73
|
+
beta_slow: int = 1
|
|
74
|
+
mscale_value: float = 1
|
|
75
|
+
mscale_all_dim: float = 0
|
|
76
|
+
|
|
77
|
+
def initialize_cache(self, mesh: jax.sharding.Mesh):
|
|
78
|
+
"""Computes and caches the sin/cos embeddings."""
|
|
79
|
+
# The second condition is for the Qwix case, where we need to call `initialize_cache` on
|
|
80
|
+
# the abstract model. Thus, when we go to call `initialize_cache` on the concrete model,
|
|
81
|
+
# this method will have been called already, but we need to recompute the cache so that
|
|
82
|
+
# it's concrete (otherwise, it'll still be a jax.ShapeDtypeStruct).
|
|
83
|
+
if self.sin_cos_cache is not None and not isinstance(
|
|
84
|
+
self.sin_cos_cache, jax.ShapeDtypeStruct):
|
|
85
|
+
return
|
|
86
|
+
mscale_val = _yarn_get_mscale(
|
|
87
|
+
self.scaling_factor, self.mscale_value) / _yarn_get_mscale(
|
|
88
|
+
self.scaling_factor, self.mscale_all_dim)
|
|
89
|
+
replicated_sharding = NamedSharding(mesh, PartitionSpec())
|
|
90
|
+
self.mscale = jax.device_put(mscale_val, replicated_sharding)
|
|
91
|
+
self.sin_cos_cache = self._compute_sin_cos()
|
|
92
|
+
|
|
93
|
+
def _compute_inv_freq(self):
|
|
94
|
+
fractions = jnp.arange(0, self.rotary_dim, 2,
|
|
95
|
+
dtype=jnp.float32) / self.rotary_dim
|
|
96
|
+
inv_freq_extrapolation = 1.0 / (self.rope_theta**fractions)
|
|
97
|
+
inv_freq_interpolation = 1.0 / (self.scaling_factor *
|
|
98
|
+
self.rope_theta**fractions)
|
|
99
|
+
low, high = _yarn_find_correction_range(
|
|
100
|
+
self.beta_fast, self.beta_slow, self.rotary_dim, self.rope_theta,
|
|
101
|
+
self.original_max_position_embeddings)
|
|
102
|
+
|
|
103
|
+
# Get n-d rotational scaling corrected for extrapolation
|
|
104
|
+
inv_freq_mask = 1 - _yarn_linear_ramp_mask(
|
|
105
|
+
low, high, self.rotary_dim // 2).astype(jnp.float32)
|
|
106
|
+
inv_freq = inv_freq_interpolation * (
|
|
107
|
+
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
|
|
108
|
+
return inv_freq
|
|
109
|
+
|
|
110
|
+
@jax.jit
|
|
111
|
+
def _compute_sin_cos(self):
|
|
112
|
+
inv_freq_H = self._compute_inv_freq()
|
|
113
|
+
t = jnp.arange(self.original_max_position_embeddings *
|
|
114
|
+
self.scaling_factor,
|
|
115
|
+
dtype=jnp.float32)
|
|
116
|
+
freqs = jnp.einsum("...T,k->...Tk", t, inv_freq_H)
|
|
117
|
+
sin, cos = jnp.sin(freqs) * self.mscale, jnp.cos(freqs) * self.mscale
|
|
118
|
+
cache = jnp.concatenate((cos, sin), axis=-1)
|
|
119
|
+
H = cache.shape[1]
|
|
120
|
+
target_dim = ((H - 1) // 128 + 1) * 128
|
|
121
|
+
padding_amount = target_dim - self.rotary_dim
|
|
122
|
+
pad_width = ((0, 0), (0, padding_amount))
|
|
123
|
+
cache_padded = jnp.pad(cache, pad_width, mode='constant')
|
|
124
|
+
desired_layout = Layout(major_to_minor=(1, 0))
|
|
125
|
+
cache_padded = with_layout_constraint(cache_padded, desired_layout)
|
|
126
|
+
return cache_padded
|
|
127
|
+
|
|
128
|
+
def apply_rope(self, positions: jax.Array, x_TNH: jax.Array):
    """Rotate ``x_TNH`` with rows of the cached ``[cos | sin]`` table.

    The last axis of ``x_TNH`` is treated as interleaved (even, odd) pairs.
    Each pair is rotated by the matching cos/sin column gathered at
    ``positions``.

    Args:
        positions: Indices into ``self.sin_cos_cache``; assumed 1-D with one
            entry per token (TODO confirm against callers).
        x_TNH: A 3-D array (asserted). Presumably its last axis equals
            ``rotary_dim`` so it broadcasts against the ``rotary_dim // 2``
            cos/sin columns; confirm.

    Returns:
        An array with the shape of ``x_TNH``, cast to ``self.dtype``.
    """
    assert x_TNH.ndim == 3
    assert self.sin_cos_cache is not None, "RoPE cache not initialized."

    # Gather one table row per token and drop the right-hand zero padding.
    table_rows = self.sin_cos_cache[positions][:, :self.rotary_dim]
    cos_half, sin_half = jnp.split(table_rows, 2, axis=-1)
    assert cos_half.ndim == 2 and sin_half.ndim == 2
    # Add a head axis so the per-token tables broadcast over all heads.
    cos_b = jnp.expand_dims(cos_half, axis=1)
    sin_b = jnp.expand_dims(sin_half, axis=1)

    x_even = x_TNH[..., 0::2]
    x_odd = x_TNH[..., 1::2]
    rot_even = x_even * cos_b - x_odd * sin_b
    rot_odd = x_odd * cos_b + x_even * sin_b
    # Re-interleave: pair the results on a trailing axis, then flatten back.
    rotated = jnp.stack((rot_even, rot_odd), axis=-1).reshape(x_TNH.shape)
    return rotated.astype(self.dtype)
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def _yarn_get_mscale(scale, mscale):
    """Return YaRN's temperature factor used to rescale RoPE magnitudes.

    The result is ``1.0`` when ``scale <= 1``, otherwise
    ``0.1 * mscale * log(scale) + 1.0``. Both candidates are evaluated and
    selected with ``jnp.where``, so the result is a JAX array.
    """
    scaled = 0.1 * mscale * jnp.log(scale) + 1.0
    return jnp.where(scale <= 1, 1.0, scaled)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _yarn_find_correction_dim(num_rotations,
                              dim,
                              base=10000,
                              max_position_embeddings=2048):
    """Invert the RoPE wavelength formula to get a (fractional) dim index.

    Returns ``d`` such that the rotary component ``d`` of a ``dim``-wide
    embedding with base ``base`` turns ``num_rotations`` times over
    ``max_position_embeddings`` positions::

        d = dim * ln(L / (2 * pi * r)) / (2 * ln(base))
    """
    total_angle = num_rotations * 2 * math.pi
    numerator = dim * math.log(max_position_embeddings / total_angle)
    return numerator / (2 * math.log(base))
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _yarn_find_correction_range(low_rot,
                                high_rot,
                                dim,
                                base=10000,
                                max_position_embeddings=2048):
    """Return integer ``(low, high)`` dimension bounds for the YaRN ramp.

    ``low`` is the floor of the correction dim for ``low_rot``, clamped
    below at 0. ``high`` is the ceiling of the correction dim for
    ``high_rot``, clamped above at ``dim - 1``.
    """
    shared = (dim, base, max_position_embeddings)
    # Widen to whole dims: round the lower edge down and the upper edge up.
    low = max(math.floor(_yarn_find_correction_dim(low_rot, *shared)), 0)
    high = min(math.ceil(_yarn_find_correction_dim(high_rot, *shared)),
               dim - 1)
    return low, high
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def _yarn_linear_ramp_mask(min, max, dim):
    """Return a float32 mask of length ``dim`` ramping from 0 to 1.

    Entry ``i`` is ``clip((i - min) / (max - min), 0, 1)``. When
    ``min == max``, ``max`` is nudged by 0.001 to avoid dividing by zero.
    Note that ``min``/``max`` shadow the builtins; the names are kept for
    interface compatibility.
    """
    upper = max + 0.001 if min == max else max
    indices = jnp.arange(dim, dtype=jnp.float32)
    return jnp.clip((indices - min) / (upper - min), 0, 1)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
@dataclass(kw_only=True)
class GptOssRotaryEmbedding(nnx.Module):
    """
    JAX implementation of the Rotary Positional Embedding with YaRN scaling.

    The first half of the last axis is rotated against the second half (see
    ``__call__``). cos/sin values are computed for the requested positions on
    every call rather than read from a cache.
    """
    head_dim: int
    # RoPE base; unscaled inverse frequencies are rope_theta ** (-2i / head_dim).
    rope_theta: float
    # Dtype the rotated query/key are cast to.
    dtype: jnp.dtype
    # Context length used to compute the NTK-by-parts ramp bounds.
    initial_context_length: int = 4096
    # YaRN interpolation is applied only when this is > 1.0.
    rope_scaling_factor: float = 1.0
    # Rotation count defining the upper (``high``) ramp bound.
    rope_ntk_alpha: float = 1.0
    # Rotation count defining the lower (``low``) ramp bound.
    rope_ntk_beta: float = 32.0

    def _compute_concentration_and_inv_freq(self) -> Tuple[float, jax.Array]:
        """
        Computes the inverse frequencies and concentration factor for YaRN.
        See YaRN paper: https://arxiv.org/abs/2309.00071

        Returns:
            ``(concentration, inv_freq)``. ``concentration`` is a Python
            float that scales cos/sin (1.0 when ``rope_scaling_factor`` is not
            > 1.0). ``inv_freq`` is a float32 array with one entry per even
            index in ``[0, head_dim)``.
        """
        base_powers = self.rope_theta**(
            jnp.arange(0, self.head_dim, 2, dtype=jnp.float32) / self.head_dim)

        if self.rope_scaling_factor > 1.0:
            # These scalars depend only on static config. Compute them on the
            # host with ``math`` (float64), as _yarn_find_correction_dim does,
            # rather than as traced float32 jnp ops. This also keeps
            # ``concentration`` a Python float, matching the annotation.
            concentration = 0.1 * math.log(self.rope_scaling_factor) + 1.0

            d_half = self.head_dim / 2
            log_theta = math.log(self.rope_theta)
            # NTK by parts: dims at or below ``low`` keep the original
            # (extrapolated) frequency, dims at or above ``high`` use the
            # interpolated one, and dims in between are blended linearly.
            low = (d_half * math.log(self.initial_context_length /
                                     (self.rope_ntk_beta * 2 * math.pi)) /
                   log_theta)
            high = (d_half * math.log(self.initial_context_length /
                                      (self.rope_ntk_alpha * 2 * math.pi)) /
                    log_theta)

            interpolation = 1.0 / (self.rope_scaling_factor * base_powers)
            extrapolation = 1.0 / base_powers

            ramp = (jnp.arange(d_half, dtype=jnp.float32) - low) / (high - low)
            mask = 1 - jnp.clip(ramp, 0, 1)

            inv_freq = interpolation * (1 - mask) + extrapolation * mask
        else:
            concentration = 1.0
            inv_freq = 1.0 / base_powers

        return concentration, inv_freq

    def _compute_cos_sin(self,
                         positions: jax.Array) -> Tuple[jax.Array, jax.Array]:
        """Computes cosine and sine embeddings for given positions.

        Args:
            positions: A 1D array of token positions (cast to float32).

        Returns:
            ``(cos, sin)``, each of shape ``(len(positions), len(inv_freq))``
            and scaled by the YaRN concentration factor.
        """
        concentration, inv_freq_H = self._compute_concentration_and_inv_freq()

        # freqs: (T, H/2). HIGHEST precision keeps the position x frequency
        # product from being computed at reduced matmul precision.
        freqs = jnp.einsum("T,H->TH",
                           positions.astype(jnp.float32),
                           inv_freq_H,
                           precision=jax.lax.Precision.HIGHEST)

        cos = jnp.cos(freqs) * concentration
        sin = jnp.sin(freqs) * concentration
        return cos, sin

    def __call__(self, query_TNH: jax.Array, key_TNH: jax.Array,
                 positions: jax.Array) -> Tuple[jax.Array, jax.Array]:
        """
        Applies rotary embeddings to query and key tensors.

        Args:
            query_TNH: Query tensor with shape (num_tokens, num_heads, head_dim)
            key_TNH: Key tensor with shape (num_tokens, num_kv_heads, head_dim)
            positions: A 1D array of token positions.

        Returns:
            ``(rotated_query, rotated_key)`` with the input shapes, cast to
            ``self.dtype``.
        """
        # cos, sin: (T, H/2)
        cos_TH, sin_TH = self._compute_cos_sin(positions)

        # Reshape for broadcasting over heads: (T, 1, H/2)
        cos_T1H = cos_TH[:, None, :]
        sin_T1H = sin_TH[:, None, :]

        def _apply_rotation(x_TNH: jax.Array) -> jax.Array:
            # Split the last dimension into two halves that are rotated
            # against each other.
            first_half, second_half = jnp.split(x_TNH, 2, axis=-1)

            # Apply rotation
            rotated_x = jnp.concatenate([
                first_half * cos_T1H - second_half * sin_T1H,
                second_half * cos_T1H + first_half * sin_T1H
            ],
                                        axis=-1)
            return rotated_x.astype(self.dtype)

        rotated_query = _apply_rotation(query_TNH)
        rotated_key = _apply_rotation(key_TNH)

        return rotated_query, rotated_key
|