tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (251)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +22 -1
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +31 -9
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +77 -36
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -71
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +158 -63
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +53 -30
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +54 -2
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +105 -57
  232. tpu_inference/runner/utils.py +2 -2
  233. tpu_inference/spec_decode/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/__init__.py +13 -0
  235. tpu_inference/spec_decode/jax/eagle3.py +65 -19
  236. tpu_inference/tpu_info.py +14 -0
  237. tpu_inference/utils.py +72 -44
  238. tpu_inference/worker/__init__.py +13 -0
  239. tpu_inference/worker/tpu_worker.py +65 -52
  240. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  241. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  242. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  244. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  245. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  246. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  247. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  248. tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
  249. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  250. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  251. {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
@@ -1,18 +1,32 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import functools
 
 import jax
 from jax import numpy as jnp
-from jax.experimental.pallas.ops.tpu.megablox.gmm import gmm
-from jax.experimental.shard_map import shard_map
-from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
 
+from tpu_inference.kernels.megablox.gmm import gmm
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.vllm.linear_common import \
     slice_sharded_tensor_for_concatenation
-
-P = PartitionSpec
+from tpu_inference.utils import get_mesh_shape_product
 
 
-def activation_fn(activation: str, x1, x2):
+def activation_fn(activation: str, x1: jax.Array, x2: jax.Array) -> jax.Array:
     match activation:
         case "silu":
             return jax.nn.silu(x1) * x2
@@ -23,7 +37,10 @@ def activation_fn(activation: str, x1, x2):
                 f"FusedMoE does not support {activation} activation")
 
 
-def _swigluoai(x1, x2, alpha=1.702, limit=7.0):
+def _swigluoai(x1: jax.Array,
+               x2: jax.Array,
+               alpha=1.702,
+               limit=7.0) -> jax.Array:
     x1 = jnp.clip(x1, a_max=limit)
     x2 = jnp.clip(x2, a_min=-limit, a_max=limit)
 
@@ -101,142 +118,124 @@ def _get_tiling_size_for_gmm_kernel(m: int, k: int, n: int,
 def tensor_sharded_gmm_merged_column_parallel(
     lhs: jax.Array,
     rhs: jax.Array,
+    rhs_scale: jax.Array | None,
     rhs_bias: jax.Array | None,
     group_sizes: jax.Array,
-    transpose_rhs: bool,
     mesh: Mesh,
-    intermediate_size: int,
-) -> jax.Array:
-    # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
-    m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
-    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n,
-                                                 g)
-
-    _gmm = functools.partial(
-        gmm,
-        preferred_element_type=lhs.dtype,
-        tiling=(tm, tk, tn),
-        transpose_rhs=transpose_rhs,
-        group_offset=jnp.array(0),
-    )
-
-    gmm_result = shard_map(
+) -> list[jax.Array]:
+
+    def _gmm(lhs, rhs, rhs_scale, rhs_bias, group_sizes):
+        m, g, n, k = lhs.shape[0], *rhs.shape
+        tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+        return gmm(
+            lhs,
+            rhs,
+            group_sizes,
+            rhs_scale=rhs_scale,
+            rhs_bias=rhs_bias,
+            preferred_element_type=lhs.dtype,
+            tiling=(tm, tk, tn),
+            transpose_rhs=True,
+            group_offset=jnp.array(0),
+        )
+
+    rhs_scale_spec = None if rhs_scale is None else P(
+        None, None, None, ShardingAxisName.MLP_TENSOR)
+    rhs_bias_spec = None if rhs_bias is None else P(
+        None, None, ShardingAxisName.MLP_TENSOR)
+
+    gmm_result = jax.shard_map(
         _gmm,
         mesh=mesh,
-        in_specs=(P("data", None), P(None, "model", None), P("data")),
-        out_specs=(P("data", "model")),
-        check_rep=False,
-    )(lhs, rhs, group_sizes)
-
-    if rhs_bias is not None:
-
-        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
-            rhs_bis = jnp.repeat(rhs_bias_local,
-                                 group_sizes_global,
-                                 0,
-                                 total_repeat_length=m // mesh.shape["data"])
-            return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)
-
-        gmm_result = shard_map(
-            _add_bias,
-            mesh=mesh,
-            in_specs=(P("data", "model"), P(None, "model"), P("data")),
-            out_specs=(P("data", "model")),
-        )(gmm_result, rhs_bias, group_sizes)
-
-    n_shards = mesh.shape["model"]
+        in_specs=(P(ShardingAxisName.MLP_DATA,
+                    None), P(None, ShardingAxisName.MLP_TENSOR,
+                             None), rhs_scale_spec, rhs_bias_spec,
+                  P(ShardingAxisName.MLP_DATA)),
+        out_specs=(P(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR)),
+        check_vma=False,
+    )(lhs, rhs, rhs_scale, rhs_bias, group_sizes)
+
+    tp_size = get_mesh_shape_product(mesh, ShardingAxisName.MLP_TENSOR)
+    intermediate_size = gmm_result.shape[-1] // 2
     output_sizes = [intermediate_size, intermediate_size]
-
     return slice_sharded_tensor_for_concatenation(gmm_result, output_sizes,
-                                                  n_shards)
+                                                  tp_size)
 
 
 def tensor_sharded_gmm_row_parallel(
     lhs: jax.Array,
     rhs: jax.Array,
+    rhs_scale: jax.Array | None,
     rhs_bias: jax.Array | None,
     group_sizes: jax.Array,
-    transpose_rhs: bool,
     mesh: Mesh,
 ) -> jax.Array:
-    # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
-    m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
-    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m // mesh.shape["data"], k, n,
-                                                 g)
-
-    _gmm = functools.partial(
-        gmm,
-        preferred_element_type=lhs.dtype,
-        tiling=(tm, tk, tn),
-        transpose_rhs=transpose_rhs,
-        group_offset=jnp.array(0),
-    )
-
-    def _gmm_all_reduce(lhs, rhs, group_sizes):
-        r = _gmm(lhs, rhs, group_sizes)
-        return jax.lax.psum(r, axis_name="model")
-
-    gmm_result = shard_map(
+
+    def _gmm_all_reduce(lhs, rhs, rhs_scale, rhs_bias, group_sizes):
+        m, g, n, k = lhs.shape[0], *rhs.shape
+        tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+        if rhs_bias is not None:
+            shard_id = jax.lax.axis_index(ShardingAxisName.MLP_TENSOR).sum()
+            rhs_bias = jnp.where(shard_id == 0, rhs_bias, 0)
+        out = gmm(
+            lhs,
+            rhs,
+            group_sizes,
+            rhs_scale=rhs_scale,
+            rhs_bias=rhs_bias,
+            preferred_element_type=lhs.dtype,
+            tiling=(tm, tk, tn),
+            transpose_rhs=True,
+            group_offset=jnp.array(0),
+        )
+        return jax.lax.psum(out, axis_name=ShardingAxisName.MLP_TENSOR)
+
+    num_blocks = 1 if rhs_scale is None else rhs_scale.shape[1]
+    rhs_scale_spec = None if num_blocks == 1 else P(
+        None, ShardingAxisName.MLP_TENSOR, None, None)
+    rhs_bias_spec = None if rhs_bias is None else P(None, None, None)
+    gmm_result = jax.shard_map(
         _gmm_all_reduce,
         mesh=mesh,
-        in_specs=(P("data", "model"), P(None, None, "model"), P("data")),
-        out_specs=(P("data")),
-        check_rep=False,
-    )(lhs, rhs, group_sizes)
-    if rhs_bias is not None:
-
-        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
-            rhs_bis = jnp.repeat(rhs_bias_local,
-                                 group_sizes_global,
-                                 0,
-                                 total_repeat_length=m // mesh.shape["data"])
-            return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)
-
-        gmm_result = shard_map(
-            _add_bias,
-            mesh=mesh,
-            in_specs=(P("data"), P(), P("data")),
-            out_specs=(P("data")),
-        )(gmm_result, rhs_bias, group_sizes)
+        in_specs=(P(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR),
+                  P(None, None, ShardingAxisName.MLP_TENSOR), rhs_scale_spec,
+                  rhs_bias_spec, P(ShardingAxisName.MLP_DATA)),
+        out_specs=(P(ShardingAxisName.MLP_DATA)),
+        check_vma=False,
+    )(lhs, rhs, rhs_scale, rhs_bias, group_sizes)
 
-    return gmm_result
+    return gmm_result.astype(lhs.dtype)
 
 
 def expert_sharded_gmm(
     lhs: jax.Array,
     rhs: jax.Array,
+    rhs_scale: jax.Array | None,
+    rhs_bias: jax.Array | None,
     group_sizes: jax.Array,
-    transpose_rhs: bool,
+    is_last_expert: bool,
     mesh: Mesh,
-    num_experts: int,
-    ep_size: int,
 ) -> jax.Array:
-    # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
-    m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
-    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
-
+    ep_size = get_mesh_shape_product(mesh, ShardingAxisName.MLP_TENSOR)
+    ep_p_spec = P(ShardingAxisName.EXPERT)
+    num_experts = rhs.shape[0]
     num_experts_per_shard = num_experts // ep_size
     group_offset = jnp.arange(0, num_experts, num_experts_per_shard)
-    group_offset = jax.lax.with_sharding_constraint(
-        group_offset, NamedSharding(mesh, P("model")))
-
-    def _gmm(lhs, rhs, group_sizes, group_offset):
-        # Group offset for this shard. `group_offset` is sharded, and in this
-        # sharded function, it has only 1 element and `group_offset.shape` is
-        # (1,) but gmm kernel requires the group_offset to be a ()-shaped array,
-        # so we group_offset[0].
-        group_offset_of_shard = group_offset[0]
+
+    def _gmm(lhs, rhs, rhs_scale, rhs_bias, group_sizes, group_offset):
+        m, g, n, k = lhs.shape[0], *rhs.shape
+        tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+
         gmm_res = gmm(
             lhs=lhs,
             rhs=rhs,
+            rhs_scale=rhs_scale,
+            rhs_bias=rhs_bias,
             group_sizes=group_sizes,
             preferred_element_type=lhs.dtype,
             tiling=(tm, tk, tn),
-            transpose_rhs=transpose_rhs,
-            group_offset=group_offset_of_shard,
+            transpose_rhs=True,
+            group_offset=group_offset[0],
         )
         return gmm_res
 
@@ -258,35 +257,43 @@ def expert_sharded_gmm(
     # 0, 0, 0, 0 0, 0, 0, 0 0, 0, 0, 0 D, D, D, D
     # shard-0 shard-1 shard-2 shard-3
     # Each shards has 3 (row A), 2 (row B), 5 (row C) and 4 (row D).
-    gmm_res = shard_map(
+    lhs_spec = ep_p_spec if is_last_expert else P()
+    rhs_spec = ep_p_spec
+    rhs_scale_spec = None if rhs_scale is None else ep_p_spec
+    rhs_bias_spec = None if rhs_bias is None else ep_p_spec
+    gmm_res = jax.shard_map(
         _gmm,
         mesh=mesh,
-        in_specs=(P(), P("model", None, None), P(), P("model")),
-        out_specs=(P("model", None)),
-        check_rep=False,
-    )(lhs, rhs, group_sizes, group_offset)
+        in_specs=(
+            lhs_spec,
+            rhs_spec,
+            rhs_scale_spec,
+            rhs_bias_spec,
+            P(),
+            ep_p_spec,
+        ),
+        out_specs=ep_p_spec,
+        check_vma=False,
+    )(lhs, rhs, rhs_scale, rhs_bias, group_sizes, group_offset)
+
+    if not is_last_expert:
+        return gmm_res
 
     # For i-th shard, it is responsible groups (AKA experts) from
     # i*num_experts_per_shard to (i+1)*num_experts_per_shard We sum them up to
     # get total rows in that shard, and that is the size for shard to send to
    # its peers. This is also the number of non-zero rows from the gmm results.
-    # In the working example, send_sizes would be [3, 2, 5, 4]
-    send_sizes = jnp.array([
-        group_sizes[i * num_experts_per_shard:(i + 1) *
-                    num_experts_per_shard].sum() for i in range(ep_size)
-    ])
+    # In the working example, send_sizes would be [3, 2, 5, 4].
+
+    # group_sizes has shape of [num_tokens_per_shard * num_experts_per_shard].
+    # So reshaping to [num_tokens_per_shard, num_experts_per_shard] and applying
+    # sum(axis=1) will get desired send_sizes shaped [num_tokens_per_shard].
+    send_sizes = group_sizes.reshape(-1, num_experts_per_shard).sum(axis=1)
     # In the working example, input_offsets would be [0, 3, 5, 10]
     input_offsets = jnp.concatenate((jnp.array([0]), send_sizes.cumsum()[:-1]))
     output_offsets = input_offsets
     recv_sizes = send_sizes
 
-    input_offsets = jax.lax.with_sharding_constraint(
-        input_offsets, NamedSharding(mesh, P("model")))
-    send_sizes = jax.lax.with_sharding_constraint(
-        send_sizes, NamedSharding(mesh, P("model")))
-    output_offsets = jax.lax.with_sharding_constraint(
-        output_offsets, NamedSharding(mesh, P("model")))
-
     def _ragged_all_to_all(operand, input_offsets, send_sizes, output_offsets,
                            recv_sizes):
         output = jnp.zeros_like(operand)
@@ -317,7 +324,7 @@ def expert_sharded_gmm(
             send_sizes_of_shard,
             output_offsets_of_shard,
             recv_sizes_of_shard,
-            axis_name="model")
+            axis_name=ShardingAxisName.EXPERT)
 
     # Use ragged_all_to_all to send the result from gmm for each expert to all
     # the shards. In the working example, the result would be:
@@ -336,56 +343,74 @@ def expert_sharded_gmm(
     # D, D, D, D D, D, D, D D, D, D, D D, D, D, D
     # D, D, D, D D, D, D, D D, D, D, D D, D, D, D
     # shard-0 shard-1 shard-2 shard-3
-    return shard_map(
+    return jax.shard_map(
         _ragged_all_to_all,
         mesh=mesh,
-        in_specs=(P("model", None), P("model"), P("model"), P("model"), P()),
-        out_specs=(P()),
-        check_rep=False,
+        in_specs=(ep_p_spec, ep_p_spec, ep_p_spec, ep_p_spec, P()),
+        out_specs=(P(ShardingAxisName.MLP_DATA)),
+        check_vma=False,
     )(gmm_res, input_offsets, send_sizes, output_offsets, recv_sizes)
 
 
+@functools.partial(
+    jax.jit,
+    static_argnames=(
+        "topk",
+        "renormalize",
+        "mesh",
+        "use_ep",
+        "activation",
+    ),
+)
 def fused_moe_func(
     hidden_states: jax.Array,
     w1: jax.Array,
     w2: jax.Array,
+    w1_scale: jax.Array | None,
+    w2_scale: jax.Array | None,
     w1_bias: jax.Array | None,
     w2_bias: jax.Array | None,
     gating_output: jax.Array,
     topk: int,
-    global_num_experts: int,
     renormalize: bool,
-    reduce_results: bool,
     mesh: Mesh,
     use_ep: bool,
     activation: str,
-):
-    """
+) -> jax.Array:
+    """Route tokens in hidden_states into each experts based on routing.
+
     Args:
-        hidden_states: [*, hidden_size]
-        w1: [num_experts, intermediate_size * 2, hidden_size]
-        w2: [num_experts, hidden_size, intermediate_size]
-        gating_output: [*, num_experts]
+        hidden_states: [num_tokens, hidden_size]
+        w1: first moe weights [num_experts, intermediate_size * 2, hidden_size]
+        w2: second moe weights [num_experts, hidden_size, intermediate_size]
+        w1_scale: w1 scale [num_experts, num_blocks, 1, intermediate_size * 2]
+        w2_scale: w2 scale [num_experts, num_blocks, 1, hidden_size]
+        w1_bias: optional bias of w1 [num_experts, 1, intermediate_size * 2]
+        w2_bias: optional bias of w2 [num_experts, 1, hidden_size]
+        gating_output: routing information of tokens [num_tokens, num_experts]
+        topk: number of experts to choose per token.
+        renormalize: normalize gating_output.
+        mesh: mesh to perform moe.
+        use_ep: use expert parallelism.
+        activation: activation function to perform on the output of w1.
+
+    Returns:
+        Output of moe operation [num_tokens, hidden_size]
     """
-    # adapted from https://github.com/vllm-project/vllm/blob/29fa5cac1cd731026f59084d93a822921507573c/vllm/model_executor/layers/fused_moe/moe_pallas.py#L26
-    if use_ep and (w1_bias is not None or w2_bias is not None):
-        raise NotImplementedError(
-            "Bias is not supported when using expert parallelism.")
-    orig_shape = hidden_states.shape
-    hidden_size = hidden_states.shape[-1]
-    num_tokens = hidden_states.size // hidden_size
-    assert global_num_experts == w1.shape[0]
-    ep_size = mesh.shape["model"]  # only used if use_ep is True.
-    intermediate_size = w2.shape[-1]
+    num_tokens, hidden_size = hidden_states.shape
+    global_num_experts, _, padded_hidden_size = w1.shape
     dtype = hidden_states.dtype
+
     assert (num_tokens * topk) % 16 == 0, (
         "The kernel requires num_tokens * topk to be a multiple of "
         f"16 but got {num_tokens}*{topk}={num_tokens*topk}")
 
-    hidden_states = hidden_states.reshape(num_tokens, hidden_size)
-    gating_output = gating_output.reshape(num_tokens, global_num_experts)
+    assert gating_output.shape == (num_tokens, global_num_experts)
 
     topk_weights = jax.nn.softmax(gating_output.astype(jnp.float32), axis=-1)
+    # All-gather topk weights for attention dp
+    topk_weights = jax.lax.with_sharding_constraint(
+        topk_weights, NamedSharding(mesh, P(ShardingAxisName.MLP_DATA, None)))
    topk_weights, topk_indices = jax.lax.top_k(topk_weights, k=topk)
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)
@@ -405,152 +430,77 @@ def fused_moe_func(
         x = hidden_states_local[token_indices_sorted]
         return x, group_sizes_local, topk_argsort_revert_indices
 
-    x, group_sizes, topk_argsort_revert_indices = shard_map(
+    x, group_sizes, topk_argsort_revert_indices = jax.shard_map(
         _process_tokens_locally,
         mesh=mesh,
-        in_specs=(P("data", None), P("data", None)),
-        out_specs=(P("data", None), P("data"), P("data")),
-        check_rep=False,
+        in_specs=(P(ShardingAxisName.MLP_DATA,
+                    None), P(ShardingAxisName.MLP_DATA, None)),
+        out_specs=(P(ShardingAxisName.MLP_DATA, None),
+                   P(ShardingAxisName.MLP_DATA), P(ShardingAxisName.MLP_DATA)),
     )(hidden_states, topk_indices)
+
+    x = jnp.pad(x, ((0, 0), (0, padded_hidden_size - hidden_size)))
+
 
     if use_ep:
         x = expert_sharded_gmm(
             x,
             w1,
-            group_sizes,
-            transpose_rhs=True,
-            mesh=mesh,
-            num_experts=global_num_experts,
-            ep_size=ep_size,
-        )
-        x1, x2 = x[..., :intermediate_size], x[..., intermediate_size:]
-    else:
-        x1, x2 = tensor_sharded_gmm_merged_column_parallel(
-            x,
-            w1,
+            w1_scale,
             w1_bias,
             group_sizes,
-            transpose_rhs=True,
+            is_last_expert=False,
             mesh=mesh,
-            intermediate_size=intermediate_size,
         )
+        x1, x2 = jnp.split(x, 2, -1)
 
-    x = activation_fn(activation, x1, x2)
+        x = activation_fn(activation, x1, x2)
 
-    if use_ep:
         x = expert_sharded_gmm(
             x,
             w2,
+            w2_scale,
+            w2_bias,
             group_sizes,
-            transpose_rhs=True,
+            is_last_expert=True,
             mesh=mesh,
-            num_experts=global_num_experts,
-            ep_size=ep_size,
         )
     else:
-        x = jax.lax.with_sharding_constraint(
-            x, NamedSharding(mesh, P("data", "model")))
+        x1, x2 = tensor_sharded_gmm_merged_column_parallel(
+            x,
+            w1,
+            w1_scale,
+            w1_bias,
+            group_sizes,
+            mesh=mesh,
+        )
+
+        x = activation_fn(activation, x1, x2)
+
         x = tensor_sharded_gmm_row_parallel(
             x,
             w2,
+            w2_scale,
            w2_bias,
             group_sizes,
-            transpose_rhs=True,
             mesh=mesh,
         )
 
     def _finalize_output(x_local, topk_argsort_revert_indices_local,
                          topk_weights_local):
         x_local = x_local[topk_argsort_revert_indices_local].reshape(
-            -1, topk, hidden_size)
+            -1, topk, padded_hidden_size)
         x_local = x_local * jnp.expand_dims(topk_weights_local, axis=-1)
         x_local = x_local.sum(axis=-2)
         return x_local
 
-    x = shard_map(
+    x = jax.shard_map(
         _finalize_output,
         mesh=mesh,
-        in_specs=(P("data", None), P("data"), P("data", None)),
-        out_specs=(P("data", None)),
-        check_rep=False,
+        in_specs=(P(ShardingAxisName.MLP_DATA,
+                    None), P(ShardingAxisName.MLP_DATA),
+                  P(ShardingAxisName.MLP_DATA, None)),
+        out_specs=(P(ShardingAxisName.ATTN_DATA, None)),
+        check_vma=False,
     )(x, topk_argsort_revert_indices, topk_weights)
-    x = x.reshape(orig_shape)
 
-    if reduce_results:
-        x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P("data")))
-    return x
-
-
-@functools.partial(
-    jax.jit,
-    static_argnames=(
-        "topk",
-        "global_num_experts",
-        "renormalize",
-        "reduce_results",
-        "mesh",
-        "use_ep",
-        "activation",
-    ),
-)
-def fused_moe_func_padded(
-    hidden_states: jax.Array,
-    w1: jax.Array,
-    w2: jax.Array,
-    w1_bias: jax.Array | None,
-    w2_bias: jax.Array | None,
-    gating_output: jax.Array,
-    topk: int,
-    global_num_experts: int,
-    renormalize: bool,
-    reduce_results: bool,
-    mesh: Mesh,
-    use_ep: bool,
-    activation: str,
-):
-    # TODO(fanhongmin@google.com): Once the jax runner pads the input, we no longer need this.
-    hidden_size = hidden_states.shape[-1]
-    num_tokens = hidden_states.size // hidden_size
-    if num_tokens * topk < 16:
-        assert 16 % (num_tokens *
-                     topk) == 0, f"Cannot pad to 16: {num_tokens=}, {topk=}"
-        n_repeats = 16 // (num_tokens * topk)
-
-        reps = (n_repeats, ) + (1, ) * (hidden_states.ndim - 1)
-        expanded_hidden_states = jnp.tile(hidden_states, reps)
-
-        reps = (n_repeats, ) + (1, ) * (gating_output.ndim - 1)
-        expanded_gating_output = jnp.tile(gating_output, reps)
-
-        expanded_x = fused_moe_func(
-            expanded_hidden_states,
-            w1,
-            w2,
-            w1_bias,
-            w2_bias,
-            expanded_gating_output,
-            topk,
-            global_num_experts,
-            renormalize,
-            reduce_results,
-            mesh,
-            use_ep,
-            activation,
-        )
-        x = expanded_x[:hidden_states.shape[0]]
-        return x
-    else:
-        return fused_moe_func(
-            hidden_states,
-            w1,
-            w2,
-            w1_bias,
-            w2_bias,
-            gating_output,
-            topk,
-            global_num_experts,
-            renormalize,
-            reduce_results,
-            mesh,
-            use_ep,
-            activation,
-        )
+    return x[:num_tokens, :hidden_size]
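
Note: the rewritten fused_moe_func above keeps the same overall routing flow visible in the diff (softmax over gating_output, top-k expert selection, optional renormalization, grouping of token rows by expert, then grouped matmuls driven by group_sizes). The snippet below is a minimal, single-device sketch of that routing step in plain JAX, for orientation only; the shapes, values, and helper variable names are illustrative assumptions and this is not code from the package.

# Single-device sketch of the MoE routing step (no sharding, no padding).
import jax
import jax.numpy as jnp

num_tokens, num_experts, topk, hidden_size = 8, 4, 2, 16
key = jax.random.PRNGKey(0)
hidden_states = jax.random.normal(key, (num_tokens, hidden_size))
gating_output = jax.random.normal(key, (num_tokens, num_experts))

topk_weights = jax.nn.softmax(gating_output.astype(jnp.float32), axis=-1)
topk_weights, topk_indices = jax.lax.top_k(topk_weights, k=topk)
# renormalize=True branch: the kept weights sum to 1 for every token.
topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)

# Each (token, chosen expert) pair becomes one row; sort rows by expert id
# so every expert's rows are contiguous, then count rows per expert.
flat_experts = topk_indices.reshape(-1)
token_ids = jnp.repeat(jnp.arange(num_tokens), topk)
order = jnp.argsort(flat_experts)
x = hidden_states[token_ids[order]]  # rows fed to the first grouped matmul
group_sizes = jnp.bincount(flat_experts, length=num_experts)
print(group_sizes)  # number of rows routed to each expert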