tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """Utility functions for ragged paged attention."""
 import jax
 from jax._src import dtypes
@@ -13,7 +26,8 @@ def align_to(x, a):
 
 
 def get_dtype_bitwidth(dtype):
-    return dtypes.bit_width(dtype)
+    return (dtypes.bit_width(dtype)
+            if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
 
 
 def get_dtype_packing(dtype):
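
The `get_dtype_bitwidth` hunk is a version-compatibility shim: `jax._src.dtypes` is a private JAX module, and the bit-width helper has gone by different names across releases, so the code probes for `bit_width` and falls back to `itemsize_bits`. A minimal sketch of the shim in isolation, assuming only that the installed JAX exposes one of the two names:

import jax.numpy as jnp
from jax._src import dtypes  # private module; the helper name varies by release


def get_dtype_bitwidth(dtype):
    return (dtypes.bit_width(dtype)
            if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))


# e.g. bfloat16 -> 16, int8 -> 8; the surrounding get_dtype_packing helper
# presumably uses this to derive how many values fit into a 32-bit word.
print(get_dtype_bitwidth(jnp.zeros(0, jnp.bfloat16).dtype))  # 16
print(get_dtype_bitwidth(jnp.zeros(0, jnp.int8).dtype))      # 8
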
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -1,10 +1,23 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import functools
 import math
 from typing import Any, Callable, Optional, Tuple
 
 import jax
 import jax.numpy as jnp
-from jax.experimental import shard_map
 from jax.experimental.pallas.ops.tpu.paged_attention import paged_attention
 from jax.experimental.pallas.ops.tpu.splash_attention import \
     splash_attention_kernel as splash
@@ -55,11 +68,11 @@ def sharded_flash_attention(
         vmem_limit_bytes=vmem_limit_bytes)
 
     return jax.jit(
-        shard_map.shard_map(_flash_attention,
-                            mesh=mesh,
-                            in_specs=in_specs,
-                            out_specs=out_specs,
-                            check_rep=False))
+        jax.shard_map(_flash_attention,
+                      mesh=mesh,
+                      in_specs=in_specs,
+                      out_specs=out_specs,
+                      check_vma=False))
 
 
 def sharded_paged_attention(
@@ -94,12 +107,12 @@ def sharded_paged_attention(
     )
 
     return jax.jit(
-        shard_map.shard_map(
+        jax.shard_map(
             _paged_attention_fn,
             mesh=mesh,
             in_specs=in_specs,
             out_specs=out_specs,
-            check_rep=False,
+            check_vma=False,
         ))
 
 
@@ -257,7 +270,7 @@ def sharded_splash_attention(
     )
     out_specs = P("data", "model", None, None)
     return jax.jit(
-        shard_map.shard_map(
+        jax.shard_map(
             functools.partial(
                 apply_splash,
                 window_size=window_size,
@@ -267,7 +280,7 @@ def sharded_splash_attention(
             mesh=mesh,
             in_specs=in_specs,
             out_specs=out_specs,
-            check_rep=False,
+            check_vma=False,
         ))
 
 
@@ -328,12 +341,12 @@ def sharded_ragged_paged_attention(
         v_scale=v_scale,
     )
 
-    return shard_map.shard_map(
+    return jax.shard_map(
         _ragged_paged_attention,
         mesh=mesh,
         in_specs=in_specs,
         out_specs=out_specs,
-        check_rep=False,
+        check_vma=False,
     )(*args)
 
 
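The remaining hunks in this file are one mechanical migration: `jax.experimental.shard_map.shard_map` becomes the top-level `jax.shard_map`, and the `check_rep` flag is replaced by its new name, `check_vma`. A minimal before/after sketch on a single-device mesh (the axis name and the doubling function are illustrative, not from the package):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P

mesh = Mesh(jax.devices()[:1], axis_names=("model",))


def _double(x):
    # Runs once per shard on the local block of the input.
    return x * 2


# Before: from jax.experimental import shard_map
#         shard_map.shard_map(_double, mesh=mesh, in_specs=P("model"),
#                             out_specs=P("model"), check_rep=False)
# After, as in these hunks:
doubled = jax.jit(
    jax.shard_map(_double,
                  mesh=mesh,
                  in_specs=P("model"),
                  out_specs=P("model"),
                  check_vma=False))

print(doubled(jnp.arange(8.0)))  # [0. 2. 4. ... 14.]
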
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import functools
 from dataclasses import dataclass, field
 from typing import Any
@@ -0,0 +1,506 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functools
+
+import jax
+from jax import numpy as jnp
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
+
+from tpu_inference.kernels.megablox.gmm import gmm
+from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.common.utils import \
+    slice_sharded_tensor_for_concatenation
+from tpu_inference.utils import get_mesh_shape_product
+
+
+def activation_fn(activation: str, x1: jax.Array, x2: jax.Array) -> jax.Array:
+    match activation:
+        case "silu":
+            return jax.nn.silu(x1) * x2
+        case "swigluoai":
+            return _swigluoai(x1, x2)
+        case _:
+            raise NotImplementedError(
+                f"FusedMoE does not support {activation} activation")
+
+
+def _swigluoai(x1: jax.Array,
+               x2: jax.Array,
+               alpha=1.702,
+               limit=7.0) -> jax.Array:
+    x1 = jnp.clip(x1, a_max=limit)
+    x2 = jnp.clip(x2, a_min=-limit, a_max=limit)
+
+    gated_activation = x1 * jax.nn.sigmoid(alpha * x1)
+
+    return gated_activation * (x2 + 1)
+
+
+def _round_up_to_multiple_of_128_within_limit(x: int, limit: int) -> int:
+    """
+    Rounds the given integer `x` up to the nearest multiple of 128, without
+    exceeding the specified `limit`.
+
+    If `x` is less than or equal to 128, returns 128.
+    If `x` is less than `limit`, returns the smallest multiple of 128 greater
+    than or equal to `x`.
+    If `x` is greater than or equal to `limit`, searches for the largest
+    multiple of 128 less than or equal to `limit` (down to 512) that divides
+    `x` evenly, and returns it.
+    If no such candidate is found, returns `limit`.
+
+    Args:
+        x (int): The integer to round up.
+        limit (int): The upper bound (must be a multiple of 128).
+
+    Returns:
+        int: The rounded value according to the rules above.
+
+    Raises:
+        AssertionError: If `limit` is less than 128 or not a multiple of 128.
+    """
+    assert limit >= 128 and limit % 128 == 0
+    if x <= 128:
+        return 128
+    if x < limit:
+        return (x + 127) // 128 * 128
+    for candidate in range(limit, 511, -128):
+        if x % candidate == 0:
+            return candidate
+    return limit
+
+
+def _get_tiling_size_for_gmm_kernel(m: int, k: int, n: int,
+                                    g: int) -> tuple[int, int, int]:
+    """
+    Calculate optimal tiling sizes for a GMM kernel in a Mixture of Experts
+    (MoE) setting.
+
+    Args:
+        m (int): The total number of tokens.
+        k (int): The input feature dimension.
+        n (int): The output feature dimension.
+        g (int): The number of experts.
+
+    Returns:
+        tuple[int, int, int]: A tuple (tm, tk, tn) of tile sizes.
+    """
+
+    # TODO(Chengji): increase the upper limit tiling size of m when we can
+    # set the vmem size to be used for the gmm kernel.
+    # NOTE: On average each expert has m // g tokens, but as the load might
+    # be unbalanced, we double the token count when choosing the tiling size
+    # of m. 2*m//g can be either greater or less than 512: if there are 32
+    # tokens and topk=2, m = topk * num_tokens = 64, and 2*m//g will be less
+    # than 512.
+    tm = _round_up_to_multiple_of_128_within_limit(2 * m // g, 512)
+    tm = min(tm, m)  # there's a requirement that m % tm == 0
+    # k/n correspond to n_input_features/n_output_features in the matmul, so
+    # they are normally greater than 2048 unless the number of shards is
+    # large.
+    tk = _round_up_to_multiple_of_128_within_limit(k, 2048)
+    tn = _round_up_to_multiple_of_128_within_limit(n, 2048)
+    return tm, tk, tn
+
+
+def tensor_sharded_gmm_merged_column_parallel(
+    lhs: jax.Array,
+    rhs: jax.Array,
+    rhs_scale: jax.Array | None,
+    rhs_bias: jax.Array | None,
+    group_sizes: jax.Array,
+    mesh: Mesh,
+) -> list[jax.Array]:
+
+    def _gmm(lhs, rhs, rhs_scale, rhs_bias, group_sizes):
+        m, g, n, k = lhs.shape[0], *rhs.shape
+        tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+        return gmm(
+            lhs,
+            rhs,
+            group_sizes,
+            rhs_scale=rhs_scale,
+            rhs_bias=rhs_bias,
+            preferred_element_type=lhs.dtype,
+            tiling=(tm, tk, tn),
+            transpose_rhs=True,
+            group_offset=jnp.array(0),
+        )
+
+    rhs_scale_spec = None if rhs_scale is None else P(
+        None, None, None, ShardingAxisName.MLP_TENSOR)
+    rhs_bias_spec = None if rhs_bias is None else P(
+        None, None, ShardingAxisName.MLP_TENSOR)
+
+    gmm_result = jax.shard_map(
+        _gmm,
+        mesh=mesh,
+        in_specs=(P(ShardingAxisName.MLP_DATA,
+                    None), P(None, ShardingAxisName.MLP_TENSOR,
+                             None), rhs_scale_spec, rhs_bias_spec,
+                  P(ShardingAxisName.MLP_DATA)),
+        out_specs=(P(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR)),
+        check_vma=False,
+    )(lhs, rhs, rhs_scale, rhs_bias, group_sizes)
+
+    tp_size = get_mesh_shape_product(mesh, ShardingAxisName.MLP_TENSOR)
+    intermediate_size = gmm_result.shape[-1] // 2
+    output_sizes = [intermediate_size, intermediate_size]
+    return slice_sharded_tensor_for_concatenation(gmm_result, output_sizes,
+                                                  tp_size)
+
+
+def tensor_sharded_gmm_row_parallel(
+    lhs: jax.Array,
+    rhs: jax.Array,
+    rhs_scale: jax.Array | None,
+    rhs_bias: jax.Array | None,
+    group_sizes: jax.Array,
+    mesh: Mesh,
+) -> jax.Array:
+
+    def _gmm_all_reduce(lhs, rhs, rhs_scale, rhs_bias, group_sizes):
+        m, g, n, k = lhs.shape[0], *rhs.shape
+        tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+        if rhs_bias is not None:
+            shard_id = jax.lax.axis_index(ShardingAxisName.MLP_TENSOR).sum()
+            rhs_bias = jnp.where(shard_id == 0, rhs_bias, 0)
+        out = gmm(
+            lhs,
+            rhs,
+            group_sizes,
+            rhs_scale=rhs_scale,
+            rhs_bias=rhs_bias,
+            preferred_element_type=lhs.dtype,
+            tiling=(tm, tk, tn),
+            transpose_rhs=True,
+            group_offset=jnp.array(0),
+        )
+        return jax.lax.psum(out, axis_name=ShardingAxisName.MLP_TENSOR)
+
+    num_blocks = 1 if rhs_scale is None else rhs_scale.shape[1]
+    rhs_scale_spec = None if num_blocks == 1 else P(
+        None, ShardingAxisName.MLP_TENSOR, None, None)
+    rhs_bias_spec = None if rhs_bias is None else P(None, None, None)
+    gmm_result = jax.shard_map(
+        _gmm_all_reduce,
+        mesh=mesh,
+        in_specs=(P(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR),
+                  P(None, None, ShardingAxisName.MLP_TENSOR), rhs_scale_spec,
+                  rhs_bias_spec, P(ShardingAxisName.MLP_DATA)),
+        out_specs=(P(ShardingAxisName.MLP_DATA)),
+        check_vma=False,
+    )(lhs, rhs, rhs_scale, rhs_bias, group_sizes)
+
+    return gmm_result.astype(lhs.dtype)
+
+
+def expert_sharded_gmm(
+    lhs: jax.Array,
+    rhs: jax.Array,
+    rhs_scale: jax.Array | None,
+    rhs_bias: jax.Array | None,
+    group_sizes: jax.Array,
+    is_last_expert: bool,
+    mesh: Mesh,
+) -> jax.Array:
+    ep_size = get_mesh_shape_product(mesh, ShardingAxisName.MLP_TENSOR)
+    ep_p_spec = P(ShardingAxisName.EXPERT)
+    num_experts = rhs.shape[0]
+    num_experts_per_shard = num_experts // ep_size
+    group_offset = jnp.arange(0, num_experts, num_experts_per_shard)
+
+    def _gmm(lhs, rhs, rhs_scale, rhs_bias, group_sizes, group_offset):
+        m, g, n, k = lhs.shape[0], *rhs.shape
+        tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+
+        gmm_res = gmm(
+            lhs=lhs,
+            rhs=rhs,
+            rhs_scale=rhs_scale,
+            rhs_bias=rhs_bias,
+            group_sizes=group_sizes,
+            preferred_element_type=lhs.dtype,
+            tiling=(tm, tk, tn),
+            transpose_rhs=True,
+            group_offset=group_offset[0],
+        )
+        return gmm_res
+
+    # The result from gmm on each shard has the same shape, but only the rows
+    # for this shard have non-zero values. Take the below as a working
+    # example:
+    # A, A, A, A    0, 0, 0, 0    0, 0, 0, 0    0, 0, 0, 0
+    # A, A, A, A    0, 0, 0, 0    0, 0, 0, 0    0, 0, 0, 0
+    # A, A, A, A    0, 0, 0, 0    0, 0, 0, 0    0, 0, 0, 0
+    # 0, 0, 0, 0    B, B, B, B    0, 0, 0, 0    0, 0, 0, 0
+    # 0, 0, 0, 0    B, B, B, B    0, 0, 0, 0    0, 0, 0, 0
+    # 0, 0, 0, 0    0, 0, 0, 0    C, C, C, C    0, 0, 0, 0
+    # 0, 0, 0, 0    0, 0, 0, 0    C, C, C, C    0, 0, 0, 0
+    # 0, 0, 0, 0    0, 0, 0, 0    C, C, C, C    0, 0, 0, 0
+    # 0, 0, 0, 0    0, 0, 0, 0    C, C, C, C    0, 0, 0, 0
+    # 0, 0, 0, 0    0, 0, 0, 0    C, C, C, C    0, 0, 0, 0
+    # 0, 0, 0, 0    0, 0, 0, 0    0, 0, 0, 0    D, D, D, D
+    # 0, 0, 0, 0    0, 0, 0, 0    0, 0, 0, 0    D, D, D, D
+    # 0, 0, 0, 0    0, 0, 0, 0    0, 0, 0, 0    D, D, D, D
+    # 0, 0, 0, 0    0, 0, 0, 0    0, 0, 0, 0    D, D, D, D
+    #   shard-0       shard-1       shard-2       shard-3
+    # The shards have 3 (rows A), 2 (rows B), 5 (rows C) and 4 (rows D)
+    # non-zero rows respectively.
+    lhs_spec = ep_p_spec if is_last_expert else P()
+    rhs_spec = ep_p_spec
+    rhs_scale_spec = None if rhs_scale is None else ep_p_spec
+    rhs_bias_spec = None if rhs_bias is None else ep_p_spec
+    gmm_res = jax.shard_map(
+        _gmm,
+        mesh=mesh,
+        in_specs=(
+            lhs_spec,
+            rhs_spec,
+            rhs_scale_spec,
+            rhs_bias_spec,
+            P(),
+            ep_p_spec,
+        ),
+        out_specs=ep_p_spec,
+        check_vma=False,
+    )(lhs, rhs, rhs_scale, rhs_bias, group_sizes, group_offset)
+
+    if not is_last_expert:
+        return gmm_res
+
+    # The i-th shard is responsible for groups (AKA experts)
+    # i*num_experts_per_shard to (i+1)*num_experts_per_shard. We sum their
+    # group sizes to get the total rows in that shard, which is the size the
+    # shard sends to its peers. This is also the number of non-zero rows in
+    # the gmm results. In the working example, send_sizes would be
+    # [3, 2, 5, 4].
+
+    # group_sizes has shape [num_shards * num_experts_per_shard], so
+    # reshaping to [num_shards, num_experts_per_shard] and applying
+    # sum(axis=1) gives the desired send_sizes shaped [num_shards].
+    send_sizes = group_sizes.reshape(-1, num_experts_per_shard).sum(axis=1)
+    # In the working example, input_offsets would be [0, 3, 5, 10].
+    input_offsets = jnp.concatenate((jnp.array([0]), send_sizes.cumsum()[:-1]))
+    output_offsets = input_offsets
+    recv_sizes = send_sizes
+
+    def _ragged_all_to_all(operand, input_offsets, send_sizes, output_offsets,
+                           recv_sizes):
+        output = jnp.zeros_like(operand)
+
+        # input_offsets, send_sizes and output_offsets are sharded and there
+        # is only 1 element in each shard; we take the 0-th element from them
+        # just so that jnp.repeat generates the arrays with correct shape.
+        input_offsets_of_shard = jnp.repeat(input_offsets[0], ep_size)
+        send_sizes_of_shard = jnp.repeat(send_sizes[0], ep_size)
+        output_offsets_of_shard = jnp.repeat(output_offsets[0], ep_size)
+
+        # recv_sizes is replicated across shards, because all the shards
+        # receive the same data and write to the output in the same way (same
+        # output_offsets and same recv_sizes) and thus generate replicated
+        # output.
+        recv_sizes_of_shard = recv_sizes
+
+        # In the working example, for each shard, the values of the offsets
+        # and sizes would be:
+        #                          shard-0       shard-1       shard-2       shard-3
+        # input_offsets_of_shard   [0, 0, 0, 0]  [3, 3, 3, 3]  [5, 5, 5, 5]  [10,10,10,10]
+        # send_sizes_of_shard      [3, 3, 3, 3]  [2, 2, 2, 2]  [5, 5, 5, 5]  [4, 4, 4, 4]
+        # output_offsets_of_shard  [0, 0, 0, 0]  [3, 3, 3, 3]  [5, 5, 5, 5]  [10,10,10,10]
+        # recv_sizes_of_shard      [3, 2, 5, 4]  [3, 2, 5, 4]  [3, 2, 5, 4]  [3, 2, 5, 4]
+        return jax.lax.ragged_all_to_all(operand,
+                                         output,
+                                         input_offsets_of_shard,
+                                         send_sizes_of_shard,
+                                         output_offsets_of_shard,
+                                         recv_sizes_of_shard,
+                                         axis_name=ShardingAxisName.EXPERT)
+
+    # Use ragged_all_to_all to send the result from gmm for each expert to
+    # all the shards. In the working example, the result would be:
+    # A, A, A, A    A, A, A, A    A, A, A, A    A, A, A, A
+    # A, A, A, A    A, A, A, A    A, A, A, A    A, A, A, A
+    # A, A, A, A    A, A, A, A    A, A, A, A    A, A, A, A
+    # B, B, B, B    B, B, B, B    B, B, B, B    B, B, B, B
+    # B, B, B, B    B, B, B, B    B, B, B, B    B, B, B, B
+    # C, C, C, C    C, C, C, C    C, C, C, C    C, C, C, C
+    # C, C, C, C    C, C, C, C    C, C, C, C    C, C, C, C
+    # C, C, C, C    C, C, C, C    C, C, C, C    C, C, C, C
+    # C, C, C, C    C, C, C, C    C, C, C, C    C, C, C, C
+    # C, C, C, C    C, C, C, C    C, C, C, C    C, C, C, C
+    # D, D, D, D    D, D, D, D    D, D, D, D    D, D, D, D
+    # D, D, D, D    D, D, D, D    D, D, D, D    D, D, D, D
+    # D, D, D, D    D, D, D, D    D, D, D, D    D, D, D, D
+    # D, D, D, D    D, D, D, D    D, D, D, D    D, D, D, D
+    #   shard-0       shard-1       shard-2       shard-3
+    return jax.shard_map(
+        _ragged_all_to_all,
+        mesh=mesh,
+        in_specs=(ep_p_spec, ep_p_spec, ep_p_spec, ep_p_spec, P()),
+        out_specs=(P(ShardingAxisName.MLP_DATA)),
+        check_vma=False,
+    )(gmm_res, input_offsets, send_sizes, output_offsets, recv_sizes)


+@functools.partial(
+    jax.jit,
+    static_argnames=(
+        "topk",
+        "renormalize",
+        "mesh",
+        "use_ep",
+        "activation",
+    ),
+)
+def fused_moe_func(
+    hidden_states: jax.Array,
+    w1: jax.Array,
+    w2: jax.Array,
+    w1_scale: jax.Array | None,
+    w2_scale: jax.Array | None,
+    w1_bias: jax.Array | None,
+    w2_bias: jax.Array | None,
+    gating_output: jax.Array,
+    topk: int,
+    renormalize: bool,
+    mesh: Mesh,
+    use_ep: bool,
+    activation: str,
+) -> jax.Array:
+    """Routes tokens in hidden_states to experts based on gating_output.
+
+    Args:
+        hidden_states: [num_tokens, hidden_size]
+        w1: first moe weights [num_experts, intermediate_size * 2, hidden_size]
+        w2: second moe weights [num_experts, hidden_size, intermediate_size]
+        w1_scale: w1 scale [num_experts, num_blocks, 1, intermediate_size * 2]
+        w2_scale: w2 scale [num_experts, num_blocks, 1, hidden_size]
+        w1_bias: optional bias of w1 [num_experts, 1, intermediate_size * 2]
+        w2_bias: optional bias of w2 [num_experts, 1, hidden_size]
+        gating_output: routing information of tokens [num_tokens, num_experts]
+        topk: number of experts to choose per token.
+        renormalize: renormalize the top-k weights to sum to 1.
+        mesh: mesh to perform moe.
+        use_ep: use expert parallelism.
+        activation: activation function to apply to the output of w1.
+
+    Returns:
+        Output of moe operation [num_tokens, hidden_size]
+    """
+    num_tokens, hidden_size = hidden_states.shape
+    global_num_experts, _, padded_hidden_size = w1.shape
+    dtype = hidden_states.dtype
+
+    assert (num_tokens * topk) % 16 == 0, (
+        "The kernel requires num_tokens * topk to be a multiple of "
+        f"16 but got {num_tokens}*{topk}={num_tokens*topk}")
+
+    assert gating_output.shape == (num_tokens, global_num_experts)
+
+    topk_weights = jax.nn.softmax(gating_output.astype(jnp.float32), axis=-1)
+    # All-gather topk weights for attention dp
+    topk_weights = jax.lax.with_sharding_constraint(
+        topk_weights, NamedSharding(mesh, P(ShardingAxisName.MLP_DATA, None)))
+    topk_weights, topk_indices = jax.lax.top_k(topk_weights, k=topk)
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)
+    topk_weights = topk_weights.astype(dtype)
+
+    def _process_tokens_locally(hidden_states_local, topk_indices_local):
+        num_tokens_local = hidden_states_local.shape[0]
+        topk_indices_flat = topk_indices_local.flatten()
+        topk_argsort_indices = jnp.argsort(topk_indices_flat)
+        topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
+        token_indices = jnp.arange(num_tokens_local,
+                                   dtype=jnp.int32).repeat(topk)
+        token_indices_sorted = token_indices[topk_argsort_indices]
+        group_sizes_local = jnp.bincount(topk_indices_flat,
+                                         length=global_num_experts)
+
+        x = hidden_states_local[token_indices_sorted]
+        return x, group_sizes_local, topk_argsort_revert_indices
+
+    x, group_sizes, topk_argsort_revert_indices = jax.shard_map(
+        _process_tokens_locally,
+        mesh=mesh,
+        in_specs=(P(ShardingAxisName.MLP_DATA,
+                    None), P(ShardingAxisName.MLP_DATA, None)),
+        out_specs=(P(ShardingAxisName.MLP_DATA, None),
+                   P(ShardingAxisName.MLP_DATA),
+                   P(ShardingAxisName.MLP_DATA)),
+    )(hidden_states, topk_indices)
+
+    x = jnp.pad(x, ((0, 0), (0, padded_hidden_size - hidden_size)))
+
+    if use_ep:
+        x = expert_sharded_gmm(
+            x,
+            w1,
+            w1_scale,
+            w1_bias,
+            group_sizes,
+            is_last_expert=False,
+            mesh=mesh,
+        )
+        x1, x2 = jnp.split(x, 2, -1)
+
+        x = activation_fn(activation, x1, x2)
+
+        x = expert_sharded_gmm(
+            x,
+            w2,
+            w2_scale,
+            w2_bias,
+            group_sizes,
+            is_last_expert=True,
+            mesh=mesh,
+        )
+    else:
+        x1, x2 = tensor_sharded_gmm_merged_column_parallel(
+            x,
+            w1,
+            w1_scale,
+            w1_bias,
+            group_sizes,
+            mesh=mesh,
+        )
+
+        x = activation_fn(activation, x1, x2)
+
+        x = tensor_sharded_gmm_row_parallel(
+            x,
+            w2,
+            w2_scale,
+            w2_bias,
+            group_sizes,
+            mesh=mesh,
+        )
+
+    def _finalize_output(x_local, topk_argsort_revert_indices_local,
+                         topk_weights_local):
+        x_local = x_local[topk_argsort_revert_indices_local].reshape(
+            -1, topk, padded_hidden_size)
+        x_local = x_local * jnp.expand_dims(topk_weights_local, axis=-1)
+        x_local = x_local.sum(axis=-2)
+        return x_local
+
+    x = jax.shard_map(
+        _finalize_output,
+        mesh=mesh,
+        in_specs=(P(ShardingAxisName.MLP_DATA,
+                    None), P(ShardingAxisName.MLP_DATA),
+                  P(ShardingAxisName.MLP_DATA, None)),
+        out_specs=(P(ShardingAxisName.ATTN_DATA, None)),
+        check_vma=False,
+    )(x, topk_argsort_revert_indices, topk_weights)
+
+    return x[:num_tokens, :hidden_size]
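
At the heart of `fused_moe_func` is the grouping done in `_process_tokens_locally`: each token row is replicated `topk` times, the copies are sorted by expert id so that every expert's rows are contiguous (the layout `gmm` expects, with `group_sizes` giving per-expert row counts), and the inverse permutation is kept to restore token order after the expert matmuls. A minimal, CPU-runnable sketch of that bookkeeping, using hypothetical routing values:

import jax.numpy as jnp

num_tokens, topk, num_experts = 4, 2, 4
# Hypothetical routing: token 0 -> experts (2, 0), token 1 -> (1, 2), ...
topk_indices = jnp.array([[2, 0], [1, 2], [0, 3], [2, 1]])

topk_indices_flat = topk_indices.flatten()             # [2 0 1 2 0 3 2 1]
topk_argsort_indices = jnp.argsort(topk_indices_flat)  # groups rows by expert
topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
token_indices = jnp.arange(num_tokens, dtype=jnp.int32).repeat(topk)
token_indices_sorted = token_indices[topk_argsort_indices]
group_sizes = jnp.bincount(topk_indices_flat, length=num_experts)

print(group_sizes)           # [2 2 3 1]: rows per expert, fed to gmm
print(token_indices_sorted)  # [0 2 1 3 0 1 3 2]: source token of each row
# Indexing the expert outputs with topk_argsort_revert_indices restores
# token-major order before the weighted sum over each token's topk experts.
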
@@ -1,7 +1,22 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 UNQUANTIZED = "unquantized"
 MXFP4 = "mxfp4"
 AWQ = "awq"
 COMPRESSED_TENSORS = "compressed-tensors"
+FP8 = "fp8"
 
 
 def get_tpu_quant_method(quant_method: str) -> str:
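
The new FP8 constant lines up with the added tpu_inference/layers/vllm/quantization/fp8.py scheme (+119 lines in the file list above). The body of `get_tpu_quant_method` is not shown in this diff; a hypothetical sketch of the kind of string dispatch such a registry of method names enables (the SUPPORTED set and the fallback behavior are assumptions, not the package's actual logic):

UNQUANTIZED = "unquantized"
MXFP4 = "mxfp4"
AWQ = "awq"
COMPRESSED_TENSORS = "compressed-tensors"
FP8 = "fp8"

# Hypothetical registry; the real module may dispatch differently.
SUPPORTED = {UNQUANTIZED, MXFP4, AWQ, COMPRESSED_TENSORS, FP8}


def get_tpu_quant_method(quant_method: str) -> str:
    # Assumed behavior for illustration: pass known method names through and
    # treat anything else as unquantized.
    return quant_method if quant_method in SUPPORTED else UNQUANTIZED
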