tpu_inference-0.12.0.dev20251222-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/layers/jax/moe/gpt_oss_moe.py
@@ -0,0 +1,199 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from dataclasses import InitVar, dataclass
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from flax.typing import Sharding
+ from jaxtyping import Float
+
+ from tpu_inference.layers.jax.base import create_param
+ from tpu_inference.layers.jax.layers import FlaxUtils
+ from tpu_inference.layers.jax.moe.moe import Router
+
+ modeling_flax_utils = FlaxUtils()
+
+
+ @dataclass(kw_only=True)
+ class GptOssRouter(Router):
+     """Router module for Mixture-of-Experts (MoE) layers.
+
+     This module determines which experts each token should be routed to.
+     """
+     e_sharding: Sharding = ()
+
+     def __post_init__(self, rngs: nnx.Rngs):
+         """
+         Initializes the parent's kernel and adds the new bias parameter.
+         """
+         super().__post_init__(rngs)
+
+         self.bias_E = create_param(rngs,
+                                    shape=(self.num_experts, ),
+                                    dtype=self.dtype,
+                                    sharding=self.e_sharding,
+                                    random_init=self.random_init)
+
+     def __call__(self, x_TD: Float):
+         """
+         Overrides the parent's forward pass to include the bias.
+         """
+         x_TD = jnp.asarray(x_TD, self.dtype)
+         x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
+
+         router_logits_TE = jnp.einsum('TD,DE -> TE', x_TD,
+                                       self.kernel_DE.value)
+
+         router_logits_TE += self.bias_E.value
+
+         weights_TX, selected_experts_TX = jax.lax.top_k(
+             router_logits_TE, self.num_experts_per_tok)
+
+         normalized_weights_TX = jax.nn.softmax(weights_TX.astype(self.dtype),
+                                                axis=-1)
+
+         return normalized_weights_TX, selected_experts_TX
+
+
+ def _swiglu(x: Float, alpha: Float, limit: Float) -> Float:
+     """Implements the specific SwiGLU from the golden implementation."""
+     x_glu, x_linear = x[..., ::2], x[..., 1::2]
+
+     x_glu = jnp.clip(x_glu, a_max=limit)
+     x_linear = jnp.clip(x_linear, a_min=-limit, a_max=limit)
+
+     gated_activation = x_glu * jax.nn.sigmoid(alpha * x_glu)
+
+     return gated_activation * (x_linear + 1)
+
+
+ @dataclass(kw_only=True)
+ class CombineExperts(nnx.Module):
+     """Module for combining expert outputs with weighted sum."""
+     dtype: jnp.dtype
+
+     def __call__(self, down_proj_TED: Float, weights_TX: Float,
+                  indices_TX: jax.Array) -> Float:
+         """Combines expert outputs using weighted sum.
+
+         Args:
+             down_proj_TED: Expert outputs, shape (tokens, experts, hidden_dim)
+             weights_TX: Router weights, shape (tokens, experts_per_token)
+             indices_TX: Selected expert indices, shape (tokens, experts_per_token)
+
+         Returns:
+             Combined output, shape (tokens, hidden_dim)
+         """
+         with jax.named_scope("combine_experts"):
+             indices_for_gather = indices_TX[..., None]
+             gathered_down_proj_TED = jnp.take_along_axis(down_proj_TED,
+                                                          indices_for_gather,
+                                                          axis=1)
+             output_TD = jnp.einsum('TXD,TX -> TD', gathered_down_proj_TED,
+                                    weights_TX)
+
+         return output_TD.astype(self.dtype)
+
+
+ @dataclass(kw_only=True)
+ class GptOssMoE(nnx.Module):
+     """
+     JAX implementation of the GPT-OSS Mixture-of-Experts MLP block.
+     """
+     dtype: jnp.dtype
+     hidden_size: int
+     intermediate_size_moe: int
+     num_local_experts: int
+     router: GptOssRouter
+     rngs: InitVar[nnx.Rngs]
+
+     swiglu_limit: float = 7.0
+     swiglu_alpha: float = 1.702
+
+     # Sharding specifications
+     activation_ffw_td: Sharding
+     edf_sharding: Sharding
+     efd_sharding: Sharding
+     ed_sharding: Sharding
+
+     random_init: bool = False
+
+     def __call__(self, x_TD: Float) -> Float:
+         """Performs the forward pass for the GPT-OSS MoE layer."""
+         x_TD = jnp.asarray(x_TD, self.dtype)
+         x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
+
+         weights_TX, indices_TX = self.router(x_TD)
+
+         # First MLP layer (up-projection)
+         with jax.named_scope("MLP #1"):
+             up_proj_TEF2 = jnp.einsum('TD,EDF -> TEF', x_TD,
+                                       self.mlp1_weight_EDF2.value)
+             up_proj_TEF2 += self.mlp1_bias_EF2.value
+
+         fuse_TEF = _swiglu(up_proj_TEF2,
+                            alpha=self.swiglu_alpha,
+                            limit=self.swiglu_limit)
+
+         # Second MLP layer (down-projection)
+         with jax.named_scope("MLP #2"):
+             down_proj_TED = jnp.einsum('TEF,EFD -> TED', fuse_TEF,
+                                        self.mlp2_weight_EFD.value)
+             down_proj_TED += self.mlp2_bias_ED.value
+
+         # Weighted sum of expert outputs
+         output_TD = self.combine_experts(down_proj_TED, weights_TX, indices_TX)
+
+         return output_TD
+
+     def __post_init__(self, rngs: nnx.Rngs):
+         """Initializes all weights and biases for the MoE block."""
+         D, F, E = self.hidden_size, self.intermediate_size_moe, self.num_local_experts
+
+         self.combine_experts = CombineExperts(dtype=self.dtype)
+
+         # MLP #1 Weights (Combined Gate and Up-projection) and Bias
+         self.mlp1_weight_EDF2 = create_param(
+             rngs,
+             shape=(E, D, F * 2),
+             dtype=self.dtype,
+             sharding=self.edf_sharding,
+             random_init=self.random_init,
+         )
+         self.mlp1_bias_EF2 = create_param(
+             rngs,
+             shape=(E, F * 2),
+             dtype=self.dtype,
+             sharding=self.ed_sharding,
+             random_init=self.random_init,
+         )
+
+         # MLP #2 Weights (Down-projection) and Bias
+         self.mlp2_weight_EFD = create_param(
+             rngs,
+             shape=(E, F, D),
+             dtype=self.dtype,
+             sharding=self.efd_sharding,
+             random_init=self.random_init,
+         )
+         self.mlp2_bias_ED = create_param(
+             rngs,
+             shape=(E, D),
+             dtype=self.dtype,
+             sharding=self.ed_sharding,
+             random_init=self.random_init,
+         )
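For readers skimming the diff, here is a minimal, self-contained sketch of the clipped, interleaved SwiGLU that `_swiglu` above implements; the toy shapes and the `clipped_swiglu` name are illustrative only and not part of the package.

import jax
import jax.numpy as jnp

def clipped_swiglu(x, alpha=1.702, limit=7.0):
    # Even channels form the gate, odd channels the linear path (as in _swiglu).
    x_glu, x_linear = x[..., ::2], x[..., 1::2]
    x_glu = jnp.clip(x_glu, None, limit)          # clip the gate from above only
    x_linear = jnp.clip(x_linear, -limit, limit)  # clip the linear path both ways
    return x_glu * jax.nn.sigmoid(alpha * x_glu) * (x_linear + 1)

# Toy fused MLP #1 output of shape (tokens, experts, 2 * ffw_dim).
up_proj = jax.random.normal(jax.random.PRNGKey(0), (4, 2, 16))
print(clipped_swiglu(up_proj).shape)  # (4, 2, 8): the last axis is halved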
tpu_inference/layers/jax/moe/moe.py
@@ -0,0 +1,249 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from dataclasses import InitVar, dataclass
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from flax.typing import Sharding
+ from jaxtyping import Float
+
+ from tpu_inference.layers.jax.base import create_param
+ from tpu_inference.layers.jax.layers import FlaxUtils
+
+ modeling_flax_utils = FlaxUtils()
+
+
+ @dataclass(kw_only=True)
+ class CombineExperts(nnx.Module):
+     """Combines expert outputs with router weights.
+
+     Supports `TED,TE -> TD` when passed expert outputs, using float32
+     accumulation for numerical stability, then casting back to the target
+     dtype.
+     """
+
+     dtype: jnp.dtype
+
+     def __call__(self, expert_outputs_TED: Float, weights_TE: Float) -> Float:
+         with jax.named_scope("combine_experts"):
+             output_TD = jnp.einsum(
+                 "TED,TE -> TD",
+                 expert_outputs_TED.astype(jnp.float32),
+                 weights_TE.astype(jnp.float32),
+                 precision="float32",
+             )
+
+         return output_TD.astype(self.dtype)
+
+
+ @dataclass(kw_only=True)
+ class Router(nnx.Module):
+     """Router module for Mixture-of-Experts (MoE) layers.
+
+     This module determines which experts each token should be routed to based on the input.
+
+     Attributes:
+     """
+     dtype: jnp.dtype
+     hidden_size: int
+     num_experts: int
+     num_experts_per_tok: int
+     router_act: str
+     rngs: InitVar[nnx.Rngs]
+     activation_ffw_td: Sharding
+     ed_sharding: Sharding
+     random_init: bool = False
+
+     def __call__(self, x_TD: Float):
+         """Routes tokens to experts.
+
+         Args:
+             x_TD: Input array of shape (sequence_length, d_model).
+
+         Returns:
+             A tuple containing:
+             - normalized_weights_TX: Normalized weights for selected experts, shape (sequence_length, num_experts_per_tok).
+             - selected_experts_TX: Indices of selected experts, shape (sequence_length, num_experts_per_tok).
+         """
+         x_TD = jnp.asarray(x_TD, self.dtype)
+         x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
+         router_act = modeling_flax_utils.ACT2FN[self.router_act]
+         router_logits_TE = jnp.einsum('TD,DE -> TE', x_TD,
+                                       self.kernel_DE.value)
+         weights_TX, selected_experts_TX = jax.lax.top_k(
+             router_logits_TE, self.num_experts_per_tok)
+         if self.router_act != "sigmoid":  # sigmoid does not accept axis argument.
+             normalized_weights_TX = router_act(weights_TX.astype(self.dtype),
+                                                axis=-1)
+         else:
+             normalized_weights_TX = router_act(weights_TX.astype(self.dtype))
+         return normalized_weights_TX, selected_experts_TX
+
+     def __post_init__(self, rngs: nnx.Rngs):
+         """Generates the router kernel (weights) for routing."""
+         shape = (self.hidden_size, self.num_experts)
+         self.kernel_DE = create_param(rngs,
+                                       shape=shape,
+                                       dtype=self.dtype,
+                                       sharding=self.ed_sharding,
+                                       random_init=self.random_init)
+
+
+ @dataclass(kw_only=True)
+ class MoE(nnx.Module):
+     """Mixture-of-Experts (MoE) Routed MLP Layer.
+
+     This module implements a MoE layer with a router and multiple expert MLPs.
+
+     Attributes:
+         router: The Router module.
+     """
+     dtype: jnp.dtype
+     num_local_experts: int
+     apply_expert_weight_before_computation: bool
+     hidden_size: int
+     intermediate_size_moe: int
+     hidden_act: str
+     rngs: InitVar[nnx.Rngs]
+     router: nnx.Module
+     activation_ffw_td: Sharding
+     activation_ffw_ted: Sharding
+     edf_sharding: Sharding
+     efd_sharding: Sharding
+     random_init: bool = False
+
+     def __call__(self, x_TD: Float):
+         """Performs the forward pass of the MoE layer.
+
+         Args:
+             x_TD: Input array of shape (sequence_length, d_model).
+
+         Returns:
+             Output array of shape (sequence_length, d_model) after passing through MoE.
+         """
+         x_TD = jnp.asarray(x_TD, self.dtype)
+         x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
+         weights_TX, indices_TX = self.router(x_TD)
+         one_hot_indices_TXE = jax.nn.one_hot(
+             indices_TX, num_classes=self.num_local_experts, dtype=self.dtype)
+         full_weights_TE = jnp.sum(one_hot_indices_TXE * weights_TX[..., None],
+                                   axis=1)
+
+         # Some models use the routing scores to weight the data instead of
+         # weighting the expert outputs.
+         if self.apply_expert_weight_before_computation:
+             with jax.named_scope("pre_computing_weight"):
+                 return self._moe_fwd_preapply_router_weights(
+                     x_TD, full_weights_TE)
+         else:
+             return self._moe_fwd(x_TD, full_weights_TE)
+
+     def __post_init__(self, rngs: nnx.Rngs):
+         """Generates the kernels (weights) for the router and experts (gating, up-projection, and down-projection layers)."""
+
+         D = self.hidden_size
+         F = self.intermediate_size_moe
+         shape_gating = (self.num_local_experts, D, F)
+         shape_up = (self.num_local_experts, D, F)
+         shape_down = (self.num_local_experts, F, D)
+
+         self.kernel_gating_EDF = create_param(rngs,
+                                               shape=shape_gating,
+                                               dtype=self.dtype,
+                                               sharding=self.edf_sharding,
+                                               random_init=self.random_init)
+         self.kernel_up_proj_EDF = create_param(rngs,
+                                                shape=shape_up,
+                                                dtype=self.dtype,
+                                                sharding=self.edf_sharding,
+                                                random_init=self.random_init)
+         self.kernel_down_proj_EFD = create_param(rngs,
+                                                  shape=shape_down,
+                                                  dtype=self.dtype,
+                                                  sharding=self.efd_sharding,
+                                                  random_init=self.random_init)
+
+         # Shared combine module for combine path
+         self.combine_experts = CombineExperts(dtype=self.dtype)
+
+     def _moe_fwd_preapply_router_weights(self, x_TD: jax.Array, weights_TE):
+         """Performs the forward pass of the MoE experts with router weights pre-applied to the inputs.
+
+         Args:
+             x_TD: Input array for the experts, shape (sequence_length, hidden_size).
+             weights_TE: Router weights, shape (sequence_length, num_experts).
+
+         Returns:
+             Output array of shape (sequence_length, d_model).
+         """
+         # Data needs to be replicated since it will be weighted by the router
+         # scores before being passed to each expert.
+         num_experts = weights_TE.shape[-1]
+         x_TED = jnp.repeat(x_TD[:, None, :], num_experts, 1)
+         weights_TED = weights_TE[..., None]
+         x_TED = jnp.asarray(x_TED, self.dtype)
+
+         with jax.named_scope("activation_expert_weighting"):
+             x_TED = x_TED * weights_TED
+
+         x_TED = nnx.with_sharding_constraint(x_TED, self.activation_ffw_ted)
+         with jax.named_scope("gating"):
+             gating_TEF = jnp.einsum('TED,EDF -> TEF', x_TED,
+                                     self.kernel_gating_EDF.value)
+             activated_gating_TEF = modeling_flax_utils.ACT2FN[self.hidden_act](
+                 gating_TEF)
+         with jax.named_scope("up_projection"):
+             up_proj_TEF = jnp.einsum('TED,EDF -> TEF', x_TED,
+                                      self.kernel_up_proj_EDF.value)
+
+         fuse_TEF = activated_gating_TEF * up_proj_TEF
+
+         with jax.named_scope("down_projection"):
+             down_proj_TED = jnp.einsum('TEF,EFD -> TED', fuse_TEF,
+                                        self.kernel_down_proj_EFD.value)
+         with jax.named_scope("sum"):
+             output_TD = down_proj_TED.sum(axis=1)
+         return output_TD.astype(self.dtype)
+
+     def _moe_fwd(self, x_TD: Float, weights):
+         """Performs the basic forward pass of the MoE experts without dropping or megablocks.
+
+         Args:
+             x_TD: Input array for the experts, shape (sequence_length, d_model).
+             weights: Weights for combining expert outputs, shape (sequence_length, num_experts).
+
+         Returns:
+             Output array of shape (sequence_length, d_model).
+         """
+         x_TD = jnp.asarray(x_TD, self.dtype)
+         x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
+         with jax.named_scope("gating"):
+             gating_TEF = jnp.einsum('TD,EDF -> TEF', x_TD,
+                                     self.kernel_gating_EDF.value)
+             activated_gating_TEF = modeling_flax_utils.ACT2FN[self.hidden_act](
+                 gating_TEF)
+         with jax.named_scope("up_projection"):
+             up_proj_TEF = jnp.einsum('TD,EDF -> TEF', x_TD,
+                                      self.kernel_up_proj_EDF.value)
+
+         fuse_TEF = activated_gating_TEF * up_proj_TEF
+
+         with jax.named_scope("down_projection"):
+             down_proj_TED = jnp.einsum('TEF,EFD -> TED', fuse_TEF,
+                                        self.kernel_down_proj_EFD.value)
+         # Combine across experts
+         output_TD = self.combine_experts(down_proj_TED, weights)
+         return output_TD
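The `MoE.__call__` path above first scatters the router's top-k outputs into a dense per-expert weight matrix before either pre-scaling the inputs or combining the expert outputs. A small, self-contained sketch of that one-hot scatter, with made-up shapes and values:

import jax
import jax.numpy as jnp

T, E, X = 3, 4, 2  # tokens, local experts, experts per token
weights_TX = jax.nn.softmax(
    jax.random.normal(jax.random.PRNGKey(0), (T, X)), axis=-1)
indices_TX = jnp.array([[0, 2], [1, 3], [2, 0]])  # selected expert ids per token

one_hot_TXE = jax.nn.one_hot(indices_TX, num_classes=E)
full_weights_TE = jnp.sum(one_hot_TXE * weights_TX[..., None], axis=1)
print(full_weights_TE.shape)         # (3, 4): dense weights over all experts
print(full_weights_TE.sum(axis=-1))  # ~1.0 per token for a softmax router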
tpu_inference/layers/jax/pp_utils.py
@@ -0,0 +1,53 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import List, Protocol
+
+ from flax import nnx
+ from vllm.distributed import get_pp_group
+ from vllm.distributed.utils import get_pp_indices
+
+
+ class PPMissingLayer(nnx.Module):
+     """
+     A placeholder layer for missing layers in a pipeline parallel model.
+     """
+
+     def __init__(self, *args, **kwargs):
+         pass
+
+     def __call__(self, *args, **kwargs):
+         """Return the first arg from args or the first value from kwargs."""
+         return args[0] if args else next(iter(kwargs.values()))
+
+
+ class LayerFn(Protocol):
+
+     def __call__(self) -> nnx.Module:
+         ...
+
+
+ def make_layers(
+     num_hidden_layers: int,
+     layer_fn: LayerFn,
+ ) -> tuple[int, int, List[nnx.Module]]:
+     start_layer, end_layer = get_pp_indices(num_hidden_layers,
+                                             get_pp_group().rank_in_group,
+                                             get_pp_group().world_size)
+
+     layers = [PPMissingLayer() for _ in range(start_layer)] \
+         + [layer_fn() for _ in range(start_layer, end_layer)] \
+         + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]
+
+     return start_layer, end_layer, layers
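`make_layers` gives every pipeline rank a list of `num_hidden_layers` entries in which only this rank's contiguous slice holds real modules; the surrounding `PPMissingLayer` placeholders simply pass their input through. The sketch below illustrates that layout with `naive_pp_indices`, a hypothetical stand-in for vLLM's `get_pp_indices` (an even contiguous split is assumed here and may differ from vLLM's actual partitioning):

def naive_pp_indices(num_layers: int, rank: int, world_size: int) -> tuple[int, int]:
    # Assumed even split; the last rank absorbs any remainder.
    per_stage = num_layers // world_size
    start = rank * per_stage
    end = num_layers if rank == world_size - 1 else start + per_stage
    return start, end

num_layers, world_size = 8, 2
for rank in range(world_size):
    start, end = naive_pp_indices(num_layers, rank, world_size)
    layout = (["missing"] * start + ["real"] * (end - start) +
              ["missing"] * (num_layers - end))
    print(rank, (start, end), layout)
# rank 0 -> (0, 4): real layers 0-3, placeholders for layers 4-7
# rank 1 -> (4, 8): placeholders for layers 0-3, real layers 4-7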