tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,615 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import enum
+ from dataclasses import InitVar, dataclass
+ from functools import partial
+ from typing import Optional, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from flax.typing import Sharding
+ from jax.sharding import PartitionSpec
+ from jaxtyping import Float
+ from qwix._src.core.ragged_dot import ragged_dot as qwix_ragged_dot
+ from qwix._src.providers import ptq
+
+ from tpu_inference.layers.jax.base import create_param
+ from tpu_inference.layers.jax.layers import FlaxUtils
+ from tpu_inference.layers.jax.moe.moe import CombineExperts, MoE
+ from tpu_inference.models.jax.utils.qwix.qwix_utils import (
+     manually_quantize_qwix_activation, manually_quantize_qwix_weight)
+
+ modeling_flax_utils = FlaxUtils()
+
+
+ @dataclass
+ class DeepSeekV3Router(nnx.Module):
+     """Router module for Mixture-of-Experts (MoE) layers.
+
+     This module determines which experts each token should be routed to
+     based on the input.
+     """
+
+     hidden_size: int
+     num_experts: int
+     num_experts_per_tok: int
+     n_groups: int
+     topk_groups: int
+     norm_topk_prob: bool
+     routed_scaling_factor: float
+     dtype: jnp.dtype
+     rngs: InitVar[nnx.Rngs]
+
+     # Sharding attributes.
+     activation_ffw_td: Sharding = ()
+     ed_sharding: Sharding = ()
+     e_sharding: Sharding = ()
+
+     random_init: bool = False
+
+     router_bias_dtype: jnp.dtype = jnp.float32
+
+     def get_topk_indices(self, scores_TE: Float) -> Float:
+         """Gets the top-k expert indices for the given scores.
+
+         Args:
+             scores_TE: The scores to select the top-k experts from.
+                 Shape (sequence, num_experts).
+
+         Returns:
+             The top-k indices of the scores.
+                 Shape (sequence, num_experts_per_tok).
+         """
+         scores_TE = scores_TE + self.bias_E
+         if self.n_groups > 1:
+             experts_per_group = self.num_experts // self.n_groups
+             group_scores_TGM = jnp.reshape(
+                 scores_TE, (-1, self.n_groups, experts_per_group))
+             group_scores_TG2 = jax.lax.top_k(group_scores_TGM, k=2)[0]
+             group_scores_TG = jnp.sum(group_scores_TG2, axis=-1)
+             indices = jax.lax.top_k(group_scores_TG, k=self.topk_groups)[1]
+
+             mask_TG = jnp.any(
+                 jnp.arange(self.n_groups)[:, None] == indices[..., None, :],
+                 axis=-1)
+             mask_TE = jnp.repeat(mask_TG,
+                                  scores_TE.shape[-1] // mask_TG.shape[-1], -1)
+             scores_TE = jnp.where(mask_TE, scores_TE, 0.0)
+
+         indices_TX = jax.lax.top_k(scores_TE, k=self.num_experts_per_tok)[1]
+
+         return indices_TX
+
+     def __call__(self, x_TD: Float) -> Tuple[Float, Float]:
+         """Routes tokens to the top-k experts.
+
+         Args:
+             x_TD: Input array of shape (sequence, d_model).
+
+         Returns:
+             A tuple containing:
+                 - weights: Normalized weights for the selected experts,
+                   shape (sequence, num_experts_per_tok).
+                 - indices: Indices of the selected experts,
+                   shape (sequence, num_experts_per_tok).
+         """
+         x_TD = jnp.asarray(x_TD, self.dtype)
+         x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
+
+         scores_TE = jnp.einsum("TD,DE -> TE", x_TD, self.kernel_DE.value)
+         scores_TE = nnx.sigmoid(scores_TE)
+
+         original_scores_TE = scores_TE
+         topk_indices_TX = self.get_topk_indices(scores_TE)
+         weights_TX = jnp.take_along_axis(original_scores_TE,
+                                          topk_indices_TX,
+                                          axis=-1)
+
+         if self.norm_topk_prob:
+             weights_TX /= jnp.sum(weights_TX, axis=-1)[..., None] + 1e-20
+
+         weights_TX *= self.routed_scaling_factor
+
+         return weights_TX, topk_indices_TX
+
+     def __post_init__(self, rngs: nnx.Rngs):
+         """Creates the router kernel (weights and bias) used for routing."""
+         D = self.hidden_size
+         E = self.num_experts
+         self.kernel_DE = create_param(rngs,
+                                       shape=(D, E),
+                                       dtype=self.dtype,
+                                       sharding=self.ed_sharding,
+                                       random_init=self.random_init)
+         self.bias_E = create_param(rngs,
+                                    shape=(E, ),
+                                    dtype=self.router_bias_dtype,
+                                    sharding=self.e_sharding,
+                                    random_init=self.random_init)
+
+
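
For reference, a minimal, self-contained sketch of the group-limited top-k routing that get_topk_indices implements above, using toy sizes (8 experts in 4 groups, top-2 groups, 2 experts per token); all sizes and inputs are illustrative assumptions, not values from the package:

import jax
import jax.numpy as jnp

num_experts, n_groups, topk_groups, num_experts_per_tok = 8, 4, 2, 2
scores_TE = jax.random.uniform(jax.random.PRNGKey(0), (3, num_experts))

experts_per_group = num_experts // n_groups
group_scores_TGM = scores_TE.reshape(-1, n_groups, experts_per_group)
# Score each group by the sum of its two best experts, then keep the best groups.
group_scores_TG = jax.lax.top_k(group_scores_TGM, k=2)[0].sum(axis=-1)
group_indices = jax.lax.top_k(group_scores_TG, k=topk_groups)[1]
mask_TG = jnp.any(
    jnp.arange(n_groups)[:, None] == group_indices[..., None, :], axis=-1)
mask_TE = jnp.repeat(mask_TG, experts_per_group, axis=-1)
# Experts outside the selected groups are zeroed out before the final top-k.
masked_TE = jnp.where(mask_TE, scores_TE, 0.0)
indices_TX = jax.lax.top_k(masked_TE, k=num_experts_per_tok)[1]
print(indices_TX.shape)  # (3, 2): one row of expert ids per token
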
+ @dataclass(kw_only=True)
+ class SparseMoE(MoE):
+     """Mixture-of-Experts (MoE) routed MLP layer.
+
+     This module implements a sparse MoE layer with a router and multiple expert MLPs.
+
+     Attributes:
+         num_experts_per_tok: The number of experts each token is routed to.
+         tile_size: A tuple (batch, activation_dim, weight_dim) for GMM tiling.
+         use_megablox: If True, uses the MegaBlox GMM kernel.
+         mesh: The device mesh.
+         # TODO: need to redesign this I/O for parallelism
+         num_expert_parallelism: The size of the 'expert' mesh dimension.
+         # TODO: determine whether this comes from outside or is extracted in the MoE class
+         is_batch_sharded_by_expert: True if the batch is sharded over the 'expert' dim.
+     """
+     num_experts_per_tok: int
+     # TODO: tile_size is (tile_batch_seq, tile_activation_dim, tile_weight_dim) from MaxText
+     tile_size: tuple[int, int, int] = (128, 64, 128)
+     use_megablox: bool = False
+     mesh: jax.sharding.Mesh
+     # This should be set if and only if the model has been quantized (via Qwix).
+     quantized_dtype: Optional[jnp.dtype] = None
+
+     def __post_init__(self, rngs: nnx.Rngs):
+         super().__post_init__(rngs)
+         self.combine_experts = CombineExperts(dtype=self.dtype)
+
+         # Derive the expert sharding.
+         self.expert_axis_name = self.edf_sharding[0]
+         if self.expert_axis_name is None:
+             self.num_expert_parallelism = 1
+         else:
+             self.num_expert_parallelism = self.mesh.shape[
+                 self.expert_axis_name]
+
+         # Derive whether the data is sharded along the expert axis.
+         self.data_axis_name = self.activation_ffw_td[0]
+         self.is_batch_sharded_by_expert = (
+             self.expert_axis_name is not None) and (self.expert_axis_name
+                                                     == self.data_axis_name)
+
+     def _sort_activations(self, inputs: jax.Array,
+                           sort_indices: jax.Array) -> jax.Array:
+         """Sorts the activations (inputs) by `sort_indices` for the forward pass."""
+         return inputs[sort_indices, ...]
+
+     @staticmethod
+     def get_all_to_all_params(
+         all_shards_group_sizes,
+         shard_id,
+         num_expert_parallelism,
+         is_batch_sharded=True,
+     ):
+         """Generates params for ragged_all_to_all communication."""
+
+         class TransformStrategy(enum.Enum):
+             INPUT_OFFSET = enum.auto()
+             SEND_SIZE = enum.auto()
+             OUTPUT_OFFSET = enum.auto()
+             RECV_SIZE = enum.auto()
+
+         def transform_array(input_array, shard_id, strategy,
+                             is_batch_sharded):
+             if is_batch_sharded:
+                 if strategy == TransformStrategy.INPUT_OFFSET:
+                     local_array = input_array[shard_id]
+                     return jnp.concatenate(
+                         (jnp.array([0]), jnp.cumsum(local_array)[:-1]))
+                 elif strategy == TransformStrategy.SEND_SIZE:
+                     return input_array[shard_id]
+                 elif strategy == TransformStrategy.OUTPUT_OFFSET:
+                     zero_row = jnp.zeros((1, ) + input_array.shape[1:],
+                                          dtype=input_array.dtype)
+                     array_with_zeros = jnp.concatenate(
+                         (zero_row, input_array), axis=0)
+                     cumulated_array = jnp.cumsum(array_with_zeros,
+                                                  axis=0,
+                                                  dtype=input_array.dtype)
+                     return cumulated_array[shard_id]
+                 elif strategy == TransformStrategy.RECV_SIZE:
+                     return input_array[:, shard_id]
+                 else:
+                     raise ValueError(
+                         f"Unknown transform array strategy: {strategy}")
+             else:
+                 if strategy == TransformStrategy.INPUT_OFFSET:
+                     return jnp.zeros(num_expert_parallelism,
+                                      dtype=input_array.dtype)
+                 elif strategy == TransformStrategy.SEND_SIZE:
+                     return jnp.repeat(input_array[shard_id],
+                                       num_expert_parallelism)
+                 elif strategy == TransformStrategy.OUTPUT_OFFSET:
+                     output_offset = jnp.concatenate(
+                         (jnp.array([0]),
+                          jnp.cumsum(input_array[:-1])))[shard_id]
+                     return jnp.repeat(output_offset, num_expert_parallelism)
+                 elif strategy == TransformStrategy.RECV_SIZE:
+                     return input_array
+                 else:
+                     raise ValueError(
+                         f"Unknown transform array strategy: {strategy}")
+
+         input_offsets = transform_array(all_shards_group_sizes, shard_id,
+                                         TransformStrategy.INPUT_OFFSET,
+                                         is_batch_sharded)
+         send_sizes = transform_array(all_shards_group_sizes, shard_id,
+                                      TransformStrategy.SEND_SIZE,
+                                      is_batch_sharded)
+         output_offsets = transform_array(all_shards_group_sizes, shard_id,
+                                          TransformStrategy.OUTPUT_OFFSET,
+                                          is_batch_sharded)
+         recv_sizes = transform_array(all_shards_group_sizes, shard_id,
+                                      TransformStrategy.RECV_SIZE,
+                                      is_batch_sharded)
+         return input_offsets, send_sizes, output_offsets, recv_sizes
+
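
The four arrays returned by get_all_to_all_params above are the bookkeeping that jax.lax.ragged_all_to_all consumes. A toy worked example for the batch-sharded case, with two shards and hand-picked sizes (all numbers are illustrative assumptions, not values from the package):

import jax.numpy as jnp

# sizes[i][j] = number of rows shard i sends to expert shard j.
sizes = jnp.array([[3, 1],
                   [2, 4]])
shard_id = 1  # Compute the parameters from shard 1's point of view.

input_offsets = jnp.concatenate(
    (jnp.array([0]), jnp.cumsum(sizes[shard_id])[:-1]))
send_sizes = sizes[shard_id]
zero_row = jnp.zeros((1, 2), dtype=sizes.dtype)
output_offsets = jnp.cumsum(jnp.concatenate((zero_row, sizes), axis=0),
                            axis=0)[shard_id]
recv_sizes = sizes[:, shard_id]

print(input_offsets)   # [0 2]: where each destination's rows start in the local send buffer
print(send_sizes)      # [2 4]
print(output_offsets)  # [3 1]: rows already written by earlier shards into each destination buffer
print(recv_sizes)      # [1 4]: rows shard 1 receives from shards 0 and 1
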
+     def _local_permute(
+         self,
+         inputs,
+         global_group_sizes,
+         local_expert_size,
+         shard_index,
+         is_offset=False,
+         global_sorted_experts=None,
+     ):
+         """Permutes tokens locally within an expert shard."""
+         # global_group_sizes: (token parallelism, num_total_experts)
+         # all_shard_local_sizes: (token parallelism, num local experts in the shard)
+         all_shard_local_sizes = jax.lax.dynamic_slice_in_dim(
+             global_group_sizes,
+             shard_index * local_expert_size,
+             local_expert_size,
+             axis=1,
+         )
+         local_sizes = all_shard_local_sizes.reshape(-1)
+
+         # local_group_size: (local_expert_size, )
+         local_group_size = jnp.sum(all_shard_local_sizes, axis=0)
+
+         # Tokens are replicated across devices.
+         if is_offset:
+             global_sorted_shard_assignments = jnp.floor_divide(
+                 global_sorted_experts, local_expert_size)
+             expert_indices = jnp.where(
+                 global_sorted_shard_assignments == shard_index,
+                 jnp.mod(global_sorted_experts, local_expert_size),
+                 local_expert_size,
+             )
+
+         # Tokens are sharded across devices.
+         else:
+             base_indices = jnp.mod(jnp.arange(local_sizes.shape[0]),
+                                    local_expert_size)
+             expert_indices = jnp.repeat(base_indices,
+                                         local_sizes,
+                                         total_repeat_length=inputs.shape[0])
+
+         sorted_indices = jnp.argsort(expert_indices)
+         # Sort the inputs by the local expert indices.
+         sorted_inputs = self._sort_activations(inputs, sorted_indices)
+         # Sorted local expert ids, running from 0 to the local expert size.
+         sorted_experts_ids = expert_indices[sorted_indices]
+         return (
+             sorted_inputs,
+             sorted_indices,
+             local_group_size,
+             sorted_experts_ids,
+         )
+
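
A toy trace (values assumed for illustration) of _local_permute's is_offset=True path: tokens assigned to other shards receive the sentinel id local_expert_size, so the argsort pushes them to the end, and the `compute_expert_ids < local_expert_size` mask in the forward pass later zeroes them out:

import jax.numpy as jnp

local_expert_size, shard_index = 2, 1  # Experts {2, 3} live on shard 1.
global_sorted_experts = jnp.array([0, 1, 2, 2, 3])

shard_of_token = global_sorted_experts // local_expert_size   # [0 0 1 1 1]
expert_indices = jnp.where(shard_of_token == shard_index,
                           global_sorted_experts % local_expert_size,
                           local_expert_size)
print(expert_indices)               # [2 2 0 0 1]: foreign tokens get sentinel 2
print(jnp.argsort(expert_indices))  # [2 3 4 0 1]: local tokens first, by local expert id
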
+     def _permute(self, inputs_TD: Float, selected_experts_TX: jax.Array):
+         """Global permute: sorts tokens by their assigned expert."""
+         # Suffix t = T * X = total assignments for the local tokens (T) on this device.
+         total_tokens = inputs_TD.shape[0]
+         flat_expert_indices = selected_experts_TX.flatten()
+         sort_indices_t = jnp.argsort(flat_expert_indices)
+
+         replicated_inputs_tD = jnp.repeat(inputs_TD,
+                                           self.num_experts_per_tok,
+                                           axis=0)
+         sorted_inputs_tD = self._sort_activations(replicated_inputs_tD,
+                                                   sort_indices_t)
+
+         # Number of tokens assigned to each expert.
+         group_sizes_E = jnp.bincount(flat_expert_indices,
+                                      length=self.num_local_experts)
+
+         expert_ids = jnp.arange(self.num_local_experts)
+         total_assignments = total_tokens * self.num_experts_per_tok
+         sorted_expert_assignments_t = jnp.repeat(
+             expert_ids,
+             repeats=group_sizes_E,
+             total_repeat_length=total_assignments)
+
+         return (
+             sorted_inputs_tD,
+             sort_indices_t,
+             group_sizes_E,
+             sorted_expert_assignments_t,
+         )
+
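
_permute groups the replicated token copies so that all rows for the same expert are contiguous, which is the layout the grouped matmul in _gmm expects. A toy trace with assumed values (argsort in JAX is stable, so ties keep their original order):

import jax.numpy as jnp

num_local_experts, num_experts_per_tok = 4, 2
selected_experts_TX = jnp.array([[2, 0],
                                 [1, 2]])  # 2 tokens, 2 experts per token

flat = selected_experts_TX.flatten()       # [2 0 1 2]
sort_indices_t = jnp.argsort(flat)         # [1 2 0 3]
group_sizes_E = jnp.bincount(flat, length=num_local_experts)  # [1 1 2 0]
sorted_assignments_t = jnp.repeat(jnp.arange(num_local_experts),
                                  group_sizes_E,
                                  total_repeat_length=flat.shape[0])
print(sorted_assignments_t)                # [0 1 2 2]
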
+     def _unpermute(self, processed_tokens: jax.Array, sort_indices: jax.Array,
+                    router_weights_TX: jax.Array):
+         """Unsorts tokens to their original order and combines the expert outputs with the router weights."""
+         with jax.named_scope("unpermute"):
+             unsorted_tokens_tD = self._sort_activations(
+                 processed_tokens, jnp.argsort(sort_indices))
+             reshaped_tokens_TXD = unsorted_tokens_tD.reshape(
+                 -1, self.num_experts_per_tok, self.hidden_size)
+             return self.combine_experts(reshaped_tokens_TXD,
+                                         router_weights_TX)
+
+     def _gmm(self, inputs, kernel, group_sizes):
+         """Performs a grouped matrix multiply."""
+         num_rows = inputs.shape[0]
+         pad_amount = (self.tile_size[0] -
+                       num_rows % self.tile_size[0]) % self.tile_size[0]
+         if pad_amount > 0:
+             inputs = jnp.pad(inputs, ((0, pad_amount), (0, 0)))
+
+         if self.use_megablox:
+             # TODO: MegaBlox is used in MaxText; keep a placeholder here for a future implementation.
+             raise NotImplementedError(
+                 "MegaBlox kernel call is not implemented.")
+         else:
+             inputs = manually_quantize_qwix_activation(
+                 inputs, "ragged_dot", jnp.float8_e4m3fn, [0], {},
+                 "absmax") if self.quantized_dtype else inputs
+             ragged_dot_func = qwix_ragged_dot if self.quantized_dtype else jax.lax.ragged_dot
+             output = ragged_dot_func(
+                 lhs=inputs,
+                 rhs=kernel,
+                 group_sizes=group_sizes,
+                 preferred_element_type=self.dtype,
+             )
+
+         if pad_amount > 0:
+             output = output[:num_rows, :]
+         return output
+
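
In the unquantized path, _gmm above reduces to jax.lax.ragged_dot: contiguous row-groups of the sorted activations are multiplied against per-expert kernels. A minimal sketch with assumed toy shapes (requires a JAX version that exposes jax.lax.ragged_dot):

import jax
import jax.numpy as jnp

m, k, n, g = 8, 4, 5, 2
lhs = jnp.ones((m, k))            # Tokens already sorted by expert.
rhs = jnp.ones((g, k, n))         # One (k, n) kernel per expert.
group_sizes = jnp.array([3, 5])   # First 3 rows -> expert 0, next 5 -> expert 1.

out = jax.lax.ragged_dot(lhs, rhs, group_sizes)
print(out.shape)  # (8, 5)
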
+     @staticmethod
+     def _distributed_sparse_moe_fwd(
+         self,
+         x_TD: jax.Array,
+         router_weights_TX: jax.Array,
+         selected_experts_TX: jax.Array,
+         kernel_gating: jax.Array,
+         kernel_up_proj: jax.Array,
+         kernel_down_proj: jax.Array,
+     ):
+         """
+         The sparse MoE forward pass with fully distributed logic.
+         This assumes it is running on a distributed TPU mesh.
+         """
+
+         # 1. Global permute: permute all tokens across shards.
+         (
+             sorted_inputs,
+             global_sort_indices,
+             global_group_sizes,
+             global_sorted_experts,
+         ) = self._permute(x_TD, selected_experts_TX)
+
+         # TODO: update to 'expert' after we enable expert parallelism; currently experts are sharded along the model axis.
+         # Alternatively, we should derive it from the model init.
+         expert_shard_id = jax.lax.axis_index(self.expert_axis_name)
+         local_expert_size = self.num_local_experts // self.num_expert_parallelism
+
+         if self.num_expert_parallelism > 1:
+             if self.is_batch_sharded_by_expert:
+                 # Tokens are sharded across devices.
+                 # In this path, we assume the data (tokens) is fully sharded along the expert axis, i.e. data_axis_name == expert_axis_name.
+
+                 # 2a. Send tokens to experts (all-to-all).
+                 # Gather group sizes from all data shards.
+                 # all_shards_group_sizes: (data parallelism = expert parallelism, number of total experts)
+                 all_shards_group_sizes = jax.lax.all_gather(
+                     global_group_sizes, axis_name=self.data_axis_name)
+
+                 # all_shards_group_sizes_per_expert_shard[i][j] = number of tokens on shard[i] to be sent to expert shard[j]
+                 all_shards_group_sizes_per_expert_shard = jnp.sum(
+                     all_shards_group_sizes.reshape(
+                         self.num_expert_parallelism,  # data parallelism
+                         self.num_expert_parallelism,  # expert parallelism
+                         local_expert_size  # experts per shard
+                     ),
+                     axis=2)
+                 input_offsets, send_sizes, output_offsets, recv_sizes = self.get_all_to_all_params(
+                     all_shards_group_sizes_per_expert_shard, expert_shard_id,
+                     self.num_expert_parallelism)
+                 # Estimate the buffer size.
+                 local_total_assignments = x_TD.shape[
+                     0] * self.num_experts_per_tok
+                 global_total_assignments = local_total_assignments * self.num_expert_parallelism
+                 output_shape_est = jnp.zeros(
+                     (global_total_assignments, self.hidden_size),
+                     dtype=sorted_inputs.dtype)
+
+                 inputs_after_all2all = jax.lax.ragged_all_to_all(
+                     sorted_inputs,
+                     output_shape_est,
+                     input_offsets,
+                     send_sizes,
+                     output_offsets,
+                     recv_sizes,
+                     axis_name=self.expert_axis_name)
+
+                 # 3a. Local permute.
+                 # Get the full group sizes from all shards.
+                 full_global_group_sizes = jax.lax.all_gather(
+                     global_group_sizes, axis_name=self.expert_axis_name)
+                 (
+                     compute_inputs,
+                     local_sorted_indices,
+                     compute_group_sizes,
+                     compute_expert_ids,
+                 ) = self._local_permute(
+                     inputs_after_all2all,
+                     full_global_group_sizes,
+                     local_expert_size,
+                     shard_index=expert_shard_id,
+                     is_offset=False,
+                 )
+
+             else:
+                 # Tokens are replicated across devices.
+
+                 # 2b. No send all-to-all is needed, as the tokens are sorted and replicated on all devices.
+                 # 3b. Local "permute".
+                 (
+                     compute_inputs,
+                     local_sorted_indices,
+                     compute_group_sizes,
+                     compute_expert_ids,
+                 ) = self._local_permute(
+                     sorted_inputs,
+                     global_group_sizes[None, :],
+                     local_expert_size,
+                     shard_index=expert_shard_id,
+                     is_offset=True,
+                     global_sorted_experts=global_sorted_experts,
+                 )
+
+                 # Calculate group sizes for the return all-to-all.
+                 reshaped_group_sizes = jnp.sum(global_group_sizes.reshape(
+                     -1, local_expert_size),
+                                                axis=1)
+                 mask = compute_expert_ids < local_expert_size
+                 compute_inputs = compute_inputs * mask[..., None]
+
+         else:
+             # --- NO EXPERT PARALLELISM ---
+             compute_inputs = sorted_inputs
+             compute_group_sizes = global_group_sizes
+             compute_expert_ids = global_sorted_experts
+             local_sorted_indices = jnp.arange(sorted_inputs.shape[0])
+
+         # 4. Compute: apply the experts using a grouped matrix multiply.
+         with jax.named_scope("gating"):
+             # compute_inputs: (local total assignments, D)
+             gating_TEF = self._gmm(compute_inputs, kernel_gating,
+                                    compute_group_sizes)
+             activated_gating_TEF = modeling_flax_utils.ACT2FN[
+                 self.hidden_act](gating_TEF)
+
+         with jax.named_scope("up_projection"):
+             up_proj_TEF = self._gmm(compute_inputs, kernel_up_proj,
+                                     compute_group_sizes)
+
+         fuse_TEF = activated_gating_TEF * up_proj_TEF
+
+         with jax.named_scope("down_projection"):
+             # intermediate_output: (local total assignments, D)
+             intermediate_output = self._gmm(fuse_TEF, kernel_down_proj,
+                                             compute_group_sizes)
+
+         # 5. Return the results (all-to-all).
+         if self.num_expert_parallelism > 1:
+             local_total_assignments = x_TD.shape[0] * self.num_experts_per_tok
+             output_shape = jnp.zeros(
+                 (local_total_assignments, self.hidden_size),
+                 dtype=intermediate_output.dtype)
+
+             if self.is_batch_sharded_by_expert:
+                 # Tokens are sharded across devices.
+                 # Unsort locally before sending back.
+                 local_output = self._sort_activations(
+                     intermediate_output, jnp.argsort(local_sorted_indices))
+
+                 input_offsets, send_sizes, output_offsets, recv_sizes = self.get_all_to_all_params(
+                     jnp.transpose(all_shards_group_sizes),
+                     expert_shard_id,
+                     self.num_expert_parallelism,
+                 )
+                 final_intermediate_output = jax.lax.ragged_all_to_all(
+                     local_output,
+                     output_shape,
+                     input_offsets,
+                     send_sizes,
+                     output_offsets,
+                     recv_sizes,
+                     axis_name=self.expert_axis_name)
+             else:
+                 # Tokens are replicated across devices.
+                 input_offsets, send_sizes, output_offsets, recv_sizes = self.get_all_to_all_params(
+                     reshaped_group_sizes,
+                     expert_shard_id,
+                     self.num_expert_parallelism,
+                     is_batch_sharded=False,
+                 )
+                 final_intermediate_output = jax.lax.ragged_all_to_all(
+                     intermediate_output,
+                     output_shape,
+                     input_offsets,
+                     send_sizes,
+                     output_offsets,
+                     recv_sizes,
+                     axis_name=self.expert_axis_name)
+         else:
+             final_intermediate_output = intermediate_output
+
+         # 6. Global unpermute (on the data shard).
+         with jax.named_scope("unpermute"):
+             output_TD = self._unpermute(final_intermediate_output,
+                                         global_sort_indices,
+                                         router_weights_TX)
+
+         return output_TD
+
+     def __call__(self, x_TD: Float):
+         """Performs the forward pass of the sparse MoE layer."""
+         x_TD = jnp.asarray(x_TD, self.dtype)
+         x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
+         router_weights_TX, selected_experts_TX = self.router(x_TD)
+
+         in_specs = (
+             PartitionSpec(),  # Replicated `self`
+             PartitionSpec(*self.activation_ffw_td),  # Sharded x_TD
+             PartitionSpec(),  # Replicated router_weights_TX
+             PartitionSpec(),  # Replicated selected_experts_TX
+             PartitionSpec(*self.edf_sharding),  # Sharded gating kernel
+             PartitionSpec(*self.edf_sharding),  # Sharded up-projection kernel
+             PartitionSpec(
+                 *self.efd_sharding),  # Sharded down-projection kernel
+         )
+         out_specs = PartitionSpec(*self.activation_ffw_td)
+
+         mapped_moe_fwd = partial(jax.shard_map,
+                                  mesh=self.mesh,
+                                  in_specs=in_specs,
+                                  out_specs=out_specs,
+                                  check_vma=False)(
+                                      SparseMoE._distributed_sparse_moe_fwd)
+
+         kernel_gating_EDF = self.kernel_gating_EDF.value
+         kernel_up_proj_EDF = self.kernel_up_proj_EDF.value
+         kernel_down_proj_EFD = self.kernel_down_proj_EFD.value
+
+         if self.quantized_dtype:
+             if not isinstance(kernel_gating_EDF, ptq.WithAux):
+                 kernel_gating_EDF = manually_quantize_qwix_weight(
+                     kernel_gating_EDF, self.quantized_dtype, [0, 2], {},
+                     "absmax")
+             if not isinstance(kernel_up_proj_EDF, ptq.WithAux):
+                 kernel_up_proj_EDF = manually_quantize_qwix_weight(
+                     kernel_up_proj_EDF, self.quantized_dtype, [0, 2], {},
+                     "absmax")
+             if not isinstance(kernel_down_proj_EFD, ptq.WithAux):
+                 kernel_down_proj_EFD = manually_quantize_qwix_weight(
+                     kernel_down_proj_EFD, self.quantized_dtype, [0, 1], {},
+                     "absmax")
+             kernel_gating_EDF = kernel_gating_EDF.array
+             kernel_up_proj_EDF = kernel_up_proj_EDF.array
+             kernel_down_proj_EFD = kernel_down_proj_EFD.array
+
+         return mapped_moe_fwd(self, x_TD, router_weights_TX,
+                               selected_experts_TX, kernel_gating_EDF,
+                               kernel_up_proj_EDF, kernel_down_proj_EFD)
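
__call__ wraps the distributed forward pass with the partial(jax.shard_map, ...) pattern seen above. A minimal, self-contained sketch of that pattern on a one-device mesh (in older JAX versions shard_map lives at jax.experimental.shard_map.shard_map and check_vma may be named check_rep; the function and mesh below are illustrative assumptions, not part of the package):

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec

# A one-device, one-axis mesh keeps the sketch runnable on any host.
mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("model", ))

def per_shard_double(x):
    # Inside shard_map, `x` is the local shard of the global array.
    return x * 2

mapped = partial(jax.shard_map,
                 mesh=mesh,
                 in_specs=(PartitionSpec("model"), ),
                 out_specs=PartitionSpec("model"))(per_shard_double)

print(mapped(jnp.arange(8.0)))  # [ 0.  2.  4.  6.  8. 10. 12. 14.]
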