tpu_inference-0.11.1.dev202511150811-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (179)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +105 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +654 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +96 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +182 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +236 -0
  27. tpu_inference/__init__.py +34 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +51 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +728 -0
  37. tpu_inference/distributed/utils.py +59 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +107 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +362 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1035 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1478 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1482 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +390 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +507 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +105 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +386 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +311 -0
  120. tpu_inference/mock/__init__.py +0 -0
  121. tpu_inference/mock/vllm_config_utils.py +28 -0
  122. tpu_inference/mock/vllm_envs.py +1219 -0
  123. tpu_inference/mock/vllm_logger.py +212 -0
  124. tpu_inference/mock/vllm_logging_utils.py +15 -0
  125. tpu_inference/models/__init__.py +0 -0
  126. tpu_inference/models/common/__init__.py +0 -0
  127. tpu_inference/models/common/model_loader.py +444 -0
  128. tpu_inference/models/jax/__init__.py +0 -0
  129. tpu_inference/models/jax/deepseek_v3.py +868 -0
  130. tpu_inference/models/jax/gpt_oss.py +492 -0
  131. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  132. tpu_inference/models/jax/llama3.py +375 -0
  133. tpu_inference/models/jax/llama4.py +629 -0
  134. tpu_inference/models/jax/llama_eagle3.py +333 -0
  135. tpu_inference/models/jax/phi3.py +376 -0
  136. tpu_inference/models/jax/qwen2.py +375 -0
  137. tpu_inference/models/jax/qwen2_5_vl.py +1103 -0
  138. tpu_inference/models/jax/qwen3.py +302 -0
  139. tpu_inference/models/jax/utils/__init__.py +0 -0
  140. tpu_inference/models/jax/utils/file_utils.py +96 -0
  141. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  142. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  143. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  144. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  145. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  146. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  147. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  148. tpu_inference/models/jax/utils/quantization/quantization_utils.py +653 -0
  149. tpu_inference/models/jax/utils/weight_utils.py +529 -0
  150. tpu_inference/models/vllm/__init__.py +0 -0
  151. tpu_inference/models/vllm/vllm_model_wrapper.py +286 -0
  152. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  153. tpu_inference/platforms/__init__.py +2 -0
  154. tpu_inference/platforms/tpu_platform.py +269 -0
  155. tpu_inference/runner/__init__.py +0 -0
  156. tpu_inference/runner/block_table.py +122 -0
  157. tpu_inference/runner/compilation_manager.py +780 -0
  158. tpu_inference/runner/input_batch.py +435 -0
  159. tpu_inference/runner/kv_cache.py +132 -0
  160. tpu_inference/runner/kv_cache_manager.py +479 -0
  161. tpu_inference/runner/lora_utils.py +92 -0
  162. tpu_inference/runner/multimodal_manager.py +217 -0
  163. tpu_inference/runner/persistent_batch_manager.py +244 -0
  164. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  165. tpu_inference/runner/structured_decoding_manager.py +88 -0
  166. tpu_inference/runner/tpu_runner.py +1620 -0
  167. tpu_inference/runner/utils.py +426 -0
  168. tpu_inference/spec_decode/__init__.py +0 -0
  169. tpu_inference/spec_decode/jax/__init__.py +0 -0
  170. tpu_inference/spec_decode/jax/eagle3.py +367 -0
  171. tpu_inference/tpu_info.py +77 -0
  172. tpu_inference/utils.py +317 -0
  173. tpu_inference/worker/__init__.py +0 -0
  174. tpu_inference/worker/tpu_worker.py +321 -0
  175. tpu_inference-0.11.1.dev202511150811.dist-info/METADATA +107 -0
  176. tpu_inference-0.11.1.dev202511150811.dist-info/RECORD +179 -0
  177. tpu_inference-0.11.1.dev202511150811.dist-info/WHEEL +5 -0
  178. tpu_inference-0.11.1.dev202511150811.dist-info/licenses/LICENSE +201 -0
  179. tpu_inference-0.11.1.dev202511150811.dist-info/top_level.txt +2 -0
tpu_inference/layers/jax/moe/gpt_oss_moe.py
@@ -0,0 +1,185 @@
+ from dataclasses import InitVar, dataclass
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from flax.typing import Sharding
+ from jaxtyping import Float
+
+ from tpu_inference.layers.jax.base import create_param
+ from tpu_inference.layers.jax.layers import FlaxUtils
+ from tpu_inference.layers.jax.moe.moe import Router
+
+ modeling_flax_utils = FlaxUtils()
+
+
+ @dataclass(kw_only=True)
+ class GptOssRouter(Router):
+     """Router module for Mixture-of-Experts (MoE) layers.
+
+     This module determines which experts each token should be routed to.
+
+     """
+     e_sharding: Sharding = ()
+
+     def __post_init__(self, rngs: nnx.Rngs):
+         """
+         Initializes the parent's kernel and adds the new bias parameter.
+         """
+         super().__post_init__(rngs)
+
+         self.bias_E = create_param(rngs,
+                                    shape=(self.num_experts, ),
+                                    dtype=self.dtype,
+                                    sharding=self.e_sharding,
+                                    random_init=self.random_init)
+
+     def __call__(self, x_TD: Float):
+         """
+         Overrides the parent's forward pass to include the bias.
+         """
+         x_TD = jnp.asarray(x_TD, self.dtype)
+         x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
+
+         router_logits_TE = jnp.einsum('TD,DE -> TE', x_TD,
+                                       self.kernel_DE.value)
+
+         router_logits_TE += self.bias_E.value
+
+         weights_TX, selected_experts_TX = jax.lax.top_k(
+             router_logits_TE, self.num_experts_per_tok)
+
+         normalized_weights_TX = jax.nn.softmax(weights_TX.astype(self.dtype),
+                                                axis=-1)
+
+         return normalized_weights_TX, selected_experts_TX
+
+
+ def _swiglu(x: Float, alpha: Float, limit: Float) -> Float:
+     """Implements the specific SwiGLU from the golden implementation."""
+     x_glu, x_linear = x[..., ::2], x[..., 1::2]
+
+     x_glu = jnp.clip(x_glu, a_max=limit)
+     x_linear = jnp.clip(x_linear, a_min=-limit, a_max=limit)
+
+     gated_activation = x_glu * jax.nn.sigmoid(alpha * x_glu)
+
+     return gated_activation * (x_linear + 1)
+
+
+ @dataclass(kw_only=True)
+ class CombineExperts(nnx.Module):
+     """Module for combining expert outputs with weighted sum."""
+     dtype: jnp.dtype
+
+     def __call__(self, down_proj_TED: Float, weights_TX: Float,
+                  indices_TX: jax.Array) -> Float:
+         """Combines expert outputs using weighted sum.
+
+         Args:
+             down_proj_TED: Expert outputs, shape (tokens, experts, hidden_dim)
+             weights_TX: Router weights, shape (tokens, experts_per_token)
+             indices_TX: Selected expert indices, shape (tokens, experts_per_token)
+
+         Returns:
+             Combined output, shape (tokens, hidden_dim)
+         """
+         with jax.named_scope("combine_experts"):
+             indices_for_gather = indices_TX[..., None]
+             gathered_down_proj_TED = jnp.take_along_axis(down_proj_TED,
+                                                          indices_for_gather,
+                                                          axis=1)
+             output_TD = jnp.einsum('TXD,TX -> TD', gathered_down_proj_TED,
+                                    weights_TX)
+
+         return output_TD.astype(self.dtype)
+
+
+ @dataclass(kw_only=True)
+ class GptOssMoE(nnx.Module):
+     """
+     JAX implementation of the GPT-OSS Mixture-of-Experts MLP block.
+     """
+     dtype: jnp.dtype
+     hidden_size: int
+     intermediate_size_moe: int
+     num_local_experts: int
+     router: GptOssRouter
+     rngs: InitVar[nnx.Rngs]
+
+     swiglu_limit: float = 7.0
+     swiglu_alpha: float = 1.702
+
+     # Sharding specifications
+     activation_ffw_td: Sharding
+     edf_sharding: Sharding
+     efd_sharding: Sharding
+     ed_sharding: Sharding
+
+     random_init: bool = False
+
+     def __call__(self, x_TD: Float) -> Float:
+         """Performs the forward pass for the GPT-OSS MoE layer."""
+         x_TD = jnp.asarray(x_TD, self.dtype)
+         x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
+
+         weights_TX, indices_TX = self.router(x_TD)
+
+         # First MLP layer (up-projection)
+         with jax.named_scope("MLP #1"):
+             up_proj_TEF2 = jnp.einsum('TD,EDF -> TEF', x_TD,
+                                       self.mlp1_weight_EDF2.value)
+             up_proj_TEF2 += self.mlp1_bias_EF2.value
+
+         fuse_TEF = _swiglu(up_proj_TEF2,
+                            alpha=self.swiglu_alpha,
+                            limit=self.swiglu_limit)
+
+         # Second MLP layer (down-projection)
+         with jax.named_scope("MLP #2"):
+             down_proj_TED = jnp.einsum('TEF,EFD -> TED', fuse_TEF,
+                                        self.mlp2_weight_EFD.value)
+             down_proj_TED += self.mlp2_bias_ED.value
+
+         # Weighted sum of expert outputs
+         output_TD = self.combine_experts(down_proj_TED, weights_TX, indices_TX)
+
+         return output_TD
+
+     def __post_init__(self, rngs: nnx.Rngs):
+         """Initializes all weights and biases for the MoE block."""
+         D, F, E = self.hidden_size, self.intermediate_size_moe, self.num_local_experts
+
+         self.combine_experts = CombineExperts(dtype=self.dtype)
+
+         # MLP #1 Weights (Combined Gate and Up-projection) and Bias
+         self.mlp1_weight_EDF2 = create_param(
+             rngs,
+             shape=(E, D, F * 2),
+             dtype=self.dtype,
+             sharding=self.edf_sharding,
+             random_init=self.random_init,
+         )
+         self.mlp1_bias_EF2 = create_param(
+             rngs,
+             shape=(E, F * 2),
+             dtype=self.dtype,
+             sharding=self.ed_sharding,
+             random_init=self.random_init,
+         )
+
+         # MLP #2 Weights (Down-projection) and Bias
+         self.mlp2_weight_EFD = create_param(
+             rngs,
+             shape=(E, F, D),
+             dtype=self.dtype,
+             sharding=self.efd_sharding,
+             random_init=self.random_init,
+         )
+         self.mlp2_bias_ED = create_param(
+             rngs,
+             shape=(E, D),
+             dtype=self.dtype,
+             sharding=self.ed_sharding,
+             random_init=self.random_init,
+         )
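Note on the _swiglu above: it is not the stock SwiGLU. The even feature columns form the gate path and the odd columns the linear path, the gate is clamped from above at swiglu_limit, the linear path is clamped to [-limit, limit], and the linear path gets a + 1 offset. A minimal standalone sketch of the same math (the helper name, toy shape, and values are illustrative, not part of the package):

import jax
import jax.numpy as jnp

def swiglu_interleaved(x, alpha=1.702, limit=7.0):
    # Even columns form the gate path, odd columns the linear path,
    # mirroring the x[..., ::2] / x[..., 1::2] split in _swiglu above.
    x_glu, x_linear = x[..., ::2], x[..., 1::2]
    x_glu = jnp.clip(x_glu, None, limit)          # clamp the gate path from above
    x_linear = jnp.clip(x_linear, -limit, limit)  # clamp the linear path on both sides
    return x_glu * jax.nn.sigmoid(alpha * x_glu) * (x_linear + 1)

x = jnp.arange(8, dtype=jnp.float32).reshape(1, 8)  # one token, 2*F = 8 features
print(swiglu_interleaved(x).shape)                  # (1, 4)

With alpha = 1.702, x * sigmoid(alpha * x) is the familiar sigmoid approximation of GELU applied to the gate path.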
tpu_inference/layers/jax/moe/moe.py
@@ -0,0 +1,209 @@
+ from dataclasses import InitVar, dataclass
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from flax.typing import Sharding
+ from jaxtyping import Float
+
+ from tpu_inference.layers.jax.base import create_param
+ from tpu_inference.layers.jax.layers import FlaxUtils
+
+ modeling_flax_utils = FlaxUtils()
+
+
+ @dataclass(kw_only=True)
+ class Router(nnx.Module):
+     """Router module for Mixture-of-Experts (MoE) layers.
+
+     This module determines which experts each token should be routed to based on the input.
+
+     Attributes:
+     """
+     dtype: jnp.dtype
+     hidden_size: int
+     num_experts: int
+     num_experts_per_tok: int
+     router_act: str
+     rngs: InitVar[nnx.Rngs]
+     activation_ffw_td: Sharding
+     ed_sharding: Sharding
+     random_init: bool = False
+
+     def __call__(self, x_TD: Float):
+         """Routes tokens to experts.
+
+         Args:
+             x_TD: Input array of shape (sequence_length, d_model).
+
+         Returns:
+             A tuple containing:
+             - normalized_weights_TX: Normalized weights for selected experts, shape (sequence_length, num_experts_per_tok).
+             - selected_experts_TX: Indices of selected experts, shape (sequence_length, num_experts_per_tok).
+         """
+         x_TD = jnp.asarray(x_TD, self.dtype)
+         x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
+         router_act = modeling_flax_utils.ACT2FN[self.router_act]
+         router_logits_TE = jnp.einsum('TD,DE -> TE', x_TD,
+                                       self.kernel_DE.value)
+         weights_TX, selected_experts_TX = jax.lax.top_k(
+             router_logits_TE, self.num_experts_per_tok)
+         if self.router_act != "sigmoid":  # sigmoid does not accept axis argument.
+             normalized_weights_TX = router_act(weights_TX.astype(self.dtype),
+                                                axis=-1)
+         else:
+             normalized_weights_TX = router_act(weights_TX.astype(self.dtype))
+         return normalized_weights_TX, selected_experts_TX
+
+     def __post_init__(self, rngs: nnx.Rngs):
+         """Generates the router kernel (weights) for routing."""
+         shape = (self.hidden_size, self.num_experts)
+         self.kernel_DE = create_param(rngs,
+                                       shape=shape,
+                                       dtype=self.dtype,
+                                       sharding=self.ed_sharding,
+                                       random_init=self.random_init)
+
+
+ @dataclass(kw_only=True)
+ class MoE(nnx.Module):
+     """Mixture-of-Experts (MoE) Routed MLP Layer.
+
+     This module implements a MoE layer with a router and multiple expert MLPs.
+
+     Attributes:
+         router: The Router module.
+     """
+     dtype: jnp.dtype
+     num_local_experts: int
+     apply_expert_weight_before_computation: bool
+     hidden_size: int
+     intermediate_size_moe: int
+     hidden_act: str
+     rngs: InitVar[nnx.Rngs]
+     router: nnx.Module
+     activation_ffw_td: Sharding
+     activation_ffw_ted: Sharding
+     edf_sharding: Sharding
+     efd_sharding: Sharding
+     random_init: bool = False
+
+     def __call__(self, x_TD: Float):
+         """Performs the forward pass of the MoE layer.
+
+         Args:
+             x_TD: Input array of shape (sequence_length, d_model).
+
+         Returns:
+             Output array of shape (sequence_length, d_model) after passing through MoE.
+         """
+         x_TD = jnp.asarray(x_TD, self.dtype)
+         x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
+         weights_TX, indices_TX = self.router(x_TD)
+         one_hot_indices_TXE = jax.nn.one_hot(
+             indices_TX, num_classes=self.num_local_experts, dtype=self.dtype)
+         full_weights_TE = jnp.sum(one_hot_indices_TXE * weights_TX[..., None],
+                                   axis=1)
+
+         # Some models use the routing scores to weight the data instead of
+         # weighting the expert outputs.
+         if self.apply_expert_weight_before_computation:
+             with jax.named_scope("pre_computing_weight"):
+                 return self._moe_fwd_preapply_router_weights(
+                     x_TD, full_weights_TE)
+         else:
+             return self._moe_fwd(x_TD, full_weights_TE)
+
+     def __post_init__(self, rngs: nnx.Rngs):
+         """Generates the kernels (weights) for the router and experts (gating, up-projection, and down-projection layers)."""
+
+         D = self.hidden_size
+         F = self.intermediate_size_moe
+         shape_gating = (self.num_local_experts, D, F)
+         shape_up = (self.num_local_experts, D, F)
+         shape_down = (self.num_local_experts, F, D)
+
+         self.kernel_gating_EDF = create_param(rngs,
+                                               shape=shape_gating,
+                                               dtype=self.dtype,
+                                               sharding=self.edf_sharding,
+                                               random_init=self.random_init)
+         self.kernel_up_proj_EDF = create_param(rngs,
+                                                shape=shape_up,
+                                                dtype=self.dtype,
+                                                sharding=self.edf_sharding,
+                                                random_init=self.random_init)
+         self.kernel_down_proj_EFD = create_param(rngs,
+                                                  shape=shape_down,
+                                                  dtype=self.dtype,
+                                                  sharding=self.efd_sharding,
+                                                  random_init=self.random_init)
+
+     def _moe_fwd_preapply_router_weights(self, x_TD: jax.Array, weights_TE):
+         """Performs the forward pass of the MoE experts with router weights pre-applied to the inputs.
+
+         Args:
+             x_TD: Input array for the experts, shape (sequence_length, hidden_size).
+             weights_TE: Router weights, shape (sequence_length, num_experts).
+
+         Returns:
+             Output array of shape (sequence_length, d_model).
+         """
+         # Data needs to be replicated since it will be weighted by the router
+         # scores before being passed to each expert.
+         num_experts = weights_TE.shape[-1]
+         x_TED = jnp.repeat(x_TD[:, None, :], num_experts, 1)
+         weights_TED = weights_TE[..., None]
+         x_TED = jnp.asarray(x_TED, self.dtype)
+
+         with jax.named_scope("activation_expert_weighting"):
+             x_TED = x_TED * weights_TED
+
+         x_TED = nnx.with_sharding_constraint(x_TED, self.activation_ffw_ted)
+         with jax.named_scope("gating"):
+             gating_TEF = jnp.einsum('TED,EDF -> TEF', x_TED,
+                                     self.kernel_gating_EDF.value)
+             activated_gating_TEF = modeling_flax_utils.ACT2FN[self.hidden_act](
+                 gating_TEF)
+         with jax.named_scope("up_projection"):
+             up_proj_TEF = jnp.einsum('TED,EDF -> TEF', x_TED,
+                                      self.kernel_up_proj_EDF.value)
+
+         fuse_TEF = activated_gating_TEF * up_proj_TEF
+
+         with jax.named_scope("down_projection"):
+             down_proj_TED = jnp.einsum('TEF,EFD -> TED', fuse_TEF,
+                                        self.kernel_down_proj_EFD.value)
+         with jax.named_scope("sum"):
+             output_TD = down_proj_TED.sum(axis=1)
+         return output_TD.astype(self.dtype)
+
+     def _moe_fwd(self, x_TD: Float, weights):
+         """Performs the basic forward pass of the MoE experts without dropping or megablocks.
+
+         Args:
+             x_TD: Input array for the experts, shape (sequence_length, d_model).
+             weights: Weights for combining expert outputs, shape (sequence_length, num_experts).
+
+         Returns:
+             Output array of shape (sequence_length, d_model).
+         """
+         x_TD = jnp.asarray(x_TD, self.dtype)
+         x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
+         with jax.named_scope("gating"):
+             gating_TEF = jnp.einsum('TD,EDF -> TEF', x_TD,
+                                     self.kernel_gating_EDF.value)
+             activated_gating_TEF = modeling_flax_utils.ACT2FN[self.hidden_act](
+                 gating_TEF)
+         with jax.named_scope("up_projection"):
+             up_proj_TEF = jnp.einsum('TD,EDF -> TEF', x_TD,
+                                      self.kernel_up_proj_EDF.value)
+
+         fuse_TEF = activated_gating_TEF * up_proj_TEF
+
+         with jax.named_scope("down_projection"):
+             down_proj_TED = jnp.einsum('TEF,EFD -> TED', fuse_TEF,
+                                        self.kernel_down_proj_EFD.value)
+         with jax.named_scope("sum"):
+             output_TD = jnp.einsum('TED,TE -> TD', down_proj_TED, weights)
+         return output_TD.astype(self.dtype)
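Before dispatching to the experts, MoE.__call__ scatters the per-token top-k router weights into a dense (T, E) matrix; _moe_fwd then contracts that matrix against the expert outputs ('TED,TE -> TD'), so unselected experts contribute nothing. A small sketch of that scatter with made-up routing values (toy sizes only, not from the package):

import jax
import jax.numpy as jnp

# Toy routing state: 3 tokens (T), 4 experts (E), top-2 experts per token (X).
indices_TX = jnp.array([[0, 2], [1, 3], [2, 0]])
weights_TX = jnp.array([[0.7, 0.3], [0.6, 0.4], [0.9, 0.1]])

# Scatter the top-k weights into a dense (T, E) matrix with zeros for
# unselected experts, as the one_hot/sum pair in MoE.__call__ does.
one_hot_TXE = jax.nn.one_hot(indices_TX, num_classes=4)
full_weights_TE = jnp.sum(one_hot_TXE * weights_TX[..., None], axis=1)
print(full_weights_TE)
# [[0.7 0.  0.3 0. ]
#  [0.  0.6 0.  0.4]
#  [0.1 0.  0.9 0. ]]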
tpu_inference/layers/jax/rope.py
@@ -0,0 +1,280 @@
+ import math
+ from dataclasses import dataclass, field
+ from typing import Optional, Tuple
+
+ import jax
+ from flax import nnx
+ from jax import numpy as jnp
+ from jax.experimental.layout import Layout, with_layout_constraint
+ from jax.sharding import NamedSharding, PartitionSpec
+
+
+ @dataclass(kw_only=True)
+ class RotaryEmbedding(nnx.Module):
+     """
+     An implementation of the original rotary positional embedding.
+     """
+     rotary_dim: int
+     rope_theta: float
+     original_max_position_embeddings: int
+     dtype: jnp.dtype
+     sin_cos_cache: Optional[jax.Array] = field(init=False, default=None)
+
+     def initialize_cache(self):
+         """Computes and caches the sin/cos embeddings."""
+         if self.sin_cos_cache is None:
+             self.sin_cos_cache = self._compute_sin_cos()
+
+     def _compute_inv_freq(self):
+         fractions_H = jnp.arange(0, self.rotary_dim, 2,
+                                  dtype=jnp.float32) / self.rotary_dim
+         inv_freq_H = 1.0 / (self.rope_theta**fractions_H)
+         return inv_freq_H
+
+     def _compute_sin_cos(self):
+         inv_freq_H = self._compute_inv_freq()
+         t = jnp.arange(self.original_max_position_embeddings,
+                        dtype=jnp.float32)
+
+         freqs = jnp.einsum("...T,k->...Tk",
+                            t,
+                            inv_freq_H,
+                            precision=jax.lax.Precision.HIGHEST)
+         sin, cos = jnp.sin(freqs), jnp.cos(freqs)
+         cache = jnp.concatenate((cos, sin), axis=-1)
+         return cache
+
+     def apply_rope(self, positions: jax.Array, x_TNH: jax.Array):
+         assert x_TNH.ndim == 3
+         assert self.sin_cos_cache is not None, "RoPE cache not initialized."
+         cos_sin_TH = self.sin_cos_cache[positions]
+         # cos, sin: (T, H/2)
+         cos_TH, sin_TH = jnp.split(cos_sin_TH, 2, axis=-1)
+         assert sin_TH.ndim == 2 and cos_TH.ndim == 2
+         # cos, sin: (T, 1, H/2)
+         cos_T1H, sin_T1H = cos_TH[:, None, :], sin_TH[:, None, :]
+         # first_half, second_half: (T, N, H/2)
+         first_half_TNH, second_half_TNH = jnp.split(x_TNH, 2, axis=-1)
+         combined = jnp.concatenate([
+             first_half_TNH * cos_T1H - second_half_TNH * sin_T1H,
+             second_half_TNH * cos_T1H + first_half_TNH * sin_T1H
+         ],
+                                    axis=-1)
+         return combined.astype(self.dtype)
+
+
+ @dataclass(kw_only=True)
+ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
+     """
+     Rotary Embedding for deepseek, with scaling and YaRN method.
+     """
+     scaling_factor: float
+     beta_fast: int = 32
+     beta_slow: int = 1
+     mscale_value: float = 1
+     mscale_all_dim: float = 0
+
+     def initialize_cache(self, mesh: jax.sharding.Mesh):
+         """Computes and caches the sin/cos embeddings."""
+         # The second condition is for the Qwix case, where we need to call `initialize_cache` on
+         # the abstract model. Thus, when we go to call `initialize_cache` on the concrete model,
+         # this method will have been called already, but we need to recompute the cache so that
+         # it's concrete (otherwise, it'll still be a jax.ShapeDtypeStruct).
+         if self.sin_cos_cache is not None and not isinstance(
+                 self.sin_cos_cache, jax.ShapeDtypeStruct):
+             return
+         mscale_val = _yarn_get_mscale(
+             self.scaling_factor, self.mscale_value) / _yarn_get_mscale(
+                 self.scaling_factor, self.mscale_all_dim)
+         replicated_sharding = NamedSharding(mesh, PartitionSpec())
+         self.mscale = jax.device_put(mscale_val, replicated_sharding)
+         self.sin_cos_cache = self._compute_sin_cos()
+
+     def _compute_inv_freq(self):
+         fractions = jnp.arange(0, self.rotary_dim, 2,
+                                dtype=jnp.float32) / self.rotary_dim
+         inv_freq_extrapolation = 1.0 / (self.rope_theta**fractions)
+         inv_freq_interpolation = 1.0 / (self.scaling_factor *
+                                         self.rope_theta**fractions)
+         low, high = _yarn_find_correction_range(
+             self.beta_fast, self.beta_slow, self.rotary_dim, self.rope_theta,
+             self.original_max_position_embeddings)
+
+         # Get n-d rotational scaling corrected for extrapolation
+         inv_freq_mask = 1 - _yarn_linear_ramp_mask(
+             low, high, self.rotary_dim // 2).astype(jnp.float32)
+         inv_freq = inv_freq_interpolation * (
+             1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
+         return inv_freq
+
+     @jax.jit
+     def _compute_sin_cos(self):
+         inv_freq_H = self._compute_inv_freq()
+         t = jnp.arange(self.original_max_position_embeddings *
+                        self.scaling_factor,
+                        dtype=jnp.float32)
+         freqs = jnp.einsum("...T,k->...Tk", t, inv_freq_H)
+         sin, cos = jnp.sin(freqs) * self.mscale, jnp.cos(freqs) * self.mscale
+         cache = jnp.concatenate((cos, sin), axis=-1)
+         H = cache.shape[1]
+         target_dim = ((H - 1) // 128 + 1) * 128
+         padding_amount = target_dim - self.rotary_dim
+         pad_width = ((0, 0), (0, padding_amount))
+         cache_padded = jnp.pad(cache, pad_width, mode='constant')
+         desired_layout = Layout(major_to_minor=(1, 0))
+         cache_padded = with_layout_constraint(cache_padded, desired_layout)
+         return cache_padded
+
+     def apply_rope(self, positions: jax.Array, x_TNH: jax.Array):
+         assert x_TNH.ndim == 3
+         assert self.sin_cos_cache is not None, "RoPE cache not initialized."
+         cos_sin_padded = self.sin_cos_cache[positions]
+         cos_sin_TH = cos_sin_padded[:, :self.rotary_dim]
+         # cos, sin: (T, H/2)
+         cos_TH, sin_TH = jnp.split(cos_sin_TH, 2, axis=-1)
+         assert sin_TH.ndim == 2 and cos_TH.ndim == 2
+         # cos, sin: (T, 1, H/2)
+         cos_T1H, sin_T1H = cos_TH[:, None, :], sin_TH[:, None, :]
+         # even, odd: (T, N, H/2)
+         even_TNH, odd_TNH = x_TNH[..., ::2], x_TNH[..., 1::2]
+         combined_TNH = jnp.stack([
+             even_TNH * cos_T1H - odd_TNH * sin_T1H,
+             odd_TNH * cos_T1H + even_TNH * sin_T1H
+         ],
+                                  axis=-1).reshape(x_TNH.shape)
+         return combined_TNH.astype(self.dtype)
+
+
+ # Calculates the temperature scaling factor for YaRN to adjust
+ # RoPE embedding magnitudes.
+ def _yarn_get_mscale(scale, mscale):
+     return jnp.where(scale <= 1, 1.0, 0.1 * mscale * jnp.log(scale) + 1.0)
+
+
+ # Inverts the dim formula to find the dim based on the number of rotations.
+ def _yarn_find_correction_dim(num_rotations,
+                               dim,
+                               base=10000,
+                               max_position_embeddings=2048):
+     return (dim * math.log(max_position_embeddings /
+                            (num_rotations * 2 * math.pi))) / (2 *
+                                                               math.log(base))
+
+
+ # Finds dim range bounds based on rotations.
+ def _yarn_find_correction_range(low_rot,
+                                 high_rot,
+                                 dim,
+                                 base=10000,
+                                 max_position_embeddings=2048):
+     low = math.floor(
+         _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
+     high = math.ceil(
+         _yarn_find_correction_dim(high_rot, dim, base,
+                                   max_position_embeddings))
+     return max(low, 0), min(high, dim - 1)  # Clamp values just in case
+
+
+ # Creates a 1D mask that ramps linearly from 0 to 1 between min and max indices.
+ def _yarn_linear_ramp_mask(min, max, dim):
+     if min == max:
+         max += 0.001  # Prevent singularity
+
+     linear_func = (jnp.arange(dim, dtype=jnp.float32) - min) / (max - min)
+     ramp_func = jnp.clip(linear_func, 0, 1)
+     return ramp_func
+
+
+ @dataclass(kw_only=True)
+ class GptOssRotaryEmbedding(nnx.Module):
+     """
+     JAX implementation of the Rotary Positional Embedding with YaRN scaling.
+     """
+     head_dim: int
+     rope_theta: float
+     dtype: jnp.dtype
+     initial_context_length: int = 4096
+     rope_scaling_factor: float = 1.0
+     rope_ntk_alpha: float = 1.0
+     rope_ntk_beta: float = 32.0
+
+     def _compute_concentration_and_inv_freq(self) -> Tuple[float, jax.Array]:
+         """
+         Computes the inverse frequencies and concentration factor for YaRN.
+         See YaRN paper: https://arxiv.org/abs/2309.00071
+         """
+         freq = self.rope_theta**(
+             jnp.arange(0, self.head_dim, 2, dtype=jnp.float32) / self.head_dim)
+
+         if self.rope_scaling_factor > 1.0:
+             concentration = 0.1 * jnp.log(self.rope_scaling_factor) + 1.0
+
+             d_half = self.head_dim / 2
+             # NTK by parts
+             low = (d_half * jnp.log(self.initial_context_length /
+                                     (self.rope_ntk_beta * 2 * jnp.pi)) /
+                    jnp.log(self.rope_theta))
+             high = (d_half * jnp.log(self.initial_context_length /
+                                      (self.rope_ntk_alpha * 2 * jnp.pi)) /
+                     jnp.log(self.rope_theta))
+
+             interpolation = 1.0 / (self.rope_scaling_factor * freq)
+             extrapolation = 1.0 / freq
+
+             ramp = (jnp.arange(d_half, dtype=jnp.float32) - low) / (high - low)
+             mask = 1 - jnp.clip(ramp, 0, 1)
+
+             inv_freq = interpolation * (1 - mask) + extrapolation * mask
+         else:
+             concentration = 1.0
+             inv_freq = 1.0 / freq
+
+         return concentration, inv_freq
+
+     def _compute_cos_sin(self,
+                          positions: jax.Array) -> Tuple[jax.Array, jax.Array]:
+         """Computes cosine and sine embeddings for given positions."""
+         concentration, inv_freq_H = self._compute_concentration_and_inv_freq()
+
+         # freqs: (T, H/2)
+         freqs = jnp.einsum("T,H->TH",
+                            positions.astype(jnp.float32),
+                            inv_freq_H,
+                            precision=jax.lax.Precision.HIGHEST)
+
+         cos = jnp.cos(freqs) * concentration
+         sin = jnp.sin(freqs) * concentration
+         return cos, sin
+
+     def __call__(self, query_TNH: jax.Array, key_TNH: jax.Array,
+                  positions: jax.Array) -> Tuple[jax.Array, jax.Array]:
+         """
+         Applies rotary embeddings to query and key tensors.
+         Args:
+             query_TNH: Query tensor with shape (num_tokens, num_heads, head_dim)
+             key_TNH: Key tensor with shape (num_tokens, num_kv_heads, head_dim)
+             positions: A 1D array of token positions.
+         """
+         # cos, sin: (T, H/2)
+         cos_TH, sin_TH = self._compute_cos_sin(positions)
+
+         # Reshape for broadcasting: (T, 1, H/2)
+         cos_T1H = cos_TH[:, None, :]
+         sin_T1H = sin_TH[:, None, :]
+
+         def _apply_rotation(x_TNH: jax.Array) -> jax.Array:
+             # Split the last dimension
+             first_half, second_half = jnp.split(x_TNH, 2, axis=-1)
+
+             # Apply rotation
+             rotated_x = jnp.concatenate([
+                 first_half * cos_T1H - second_half * sin_T1H,
+                 second_half * cos_T1H + first_half * sin_T1H
+             ],
+                                         axis=-1)
+             return rotated_x.astype(self.dtype)
+
+         rotated_query = _apply_rotation(query_TNH)
+         rotated_key = _apply_rotation(key_TNH)
+
+         return rotated_query, rotated_key
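Both RotaryEmbedding and GptOssRotaryEmbedding rotate the two halves of head_dim against each other, while DeepseekScalingRotaryEmbedding.apply_rope pairs even/odd features instead. A self-contained sketch of the half-split rotation (the helper name and toy sizes are illustrative; no caching, sharding, or YaRN scaling):

import jax.numpy as jnp

def rotate_half_split(x_TNH, cos_TH, sin_TH):
    # Same rotation rule as RotaryEmbedding.apply_rope above: split head_dim
    # in half and rotate the two halves against each other.
    cos_T1H, sin_T1H = cos_TH[:, None, :], sin_TH[:, None, :]
    first, second = jnp.split(x_TNH, 2, axis=-1)
    return jnp.concatenate(
        [first * cos_T1H - second * sin_T1H,
         second * cos_T1H + first * sin_T1H], axis=-1)

T, N, H = 4, 2, 8  # toy sizes: tokens, heads, head_dim
positions = jnp.arange(T, dtype=jnp.float32)
inv_freq_H = 1.0 / (10000.0 ** (jnp.arange(0, H, 2, dtype=jnp.float32) / H))
freqs_TH = positions[:, None] * inv_freq_H[None, :]
x_TNH = jnp.ones((T, N, H))
out = rotate_half_split(x_TNH, jnp.cos(freqs_TH), jnp.sin(freqs_TH))
print(out.shape)  # (4, 2, 8)

At position 0 the rotation is the identity (cos = 1, sin = 0), so out[0] equals x_TNH[0]; that holds for either pairing variant and makes a quick sanity check.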