tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (248)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +14 -0
  31. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  32. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
  35. tests/layers/__init__.py +13 -0
  36. tests/layers/common/__init__.py +13 -0
  37. tests/layers/common/test_attention_interface.py +156 -0
  38. tests/layers/common/test_quantization.py +149 -0
  39. tests/layers/jax/__init__.py +13 -0
  40. tests/layers/jax/attention/__init__.py +13 -0
  41. tests/layers/jax/attention/test_common_attention.py +103 -0
  42. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  43. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  44. tests/layers/jax/moe/__init__.py +13 -0
  45. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  46. tests/layers/jax/sample/__init__.py +13 -0
  47. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  48. tests/layers/jax/sample/test_sampling.py +115 -0
  49. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  50. tests/layers/jax/test_layers.py +155 -0
  51. tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
  52. tests/layers/jax/test_rope.py +93 -0
  53. tests/layers/jax/test_sharding.py +159 -0
  54. tests/layers/jax/test_transformer_block.py +152 -0
  55. tests/layers/vllm/__init__.py +13 -0
  56. tests/layers/vllm/test_attention.py +363 -0
  57. tests/layers/vllm/test_awq.py +406 -0
  58. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  59. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  61. tests/layers/vllm/test_fp8.py +17 -0
  62. tests/layers/vllm/test_mxfp4.py +320 -0
  63. tests/layers/vllm/test_unquantized.py +662 -0
  64. tests/layers/vllm/utils.py +87 -0
  65. tests/lora/__init__.py +13 -0
  66. tests/lora/conftest.py +14 -0
  67. tests/lora/test_bgmv.py +14 -0
  68. tests/lora/test_layers.py +25 -8
  69. tests/lora/test_lora.py +15 -1
  70. tests/lora/test_lora_perf.py +14 -0
  71. tests/models/__init__.py +13 -0
  72. tests/models/common/__init__.py +13 -0
  73. tests/models/common/test_model_loader.py +455 -0
  74. tests/models/jax/__init__.py +13 -0
  75. tests/models/jax/test_deepseek_v3.py +401 -0
  76. tests/models/jax/test_llama3.py +184 -0
  77. tests/models/jax/test_llama4.py +298 -0
  78. tests/models/jax/test_llama_eagle3.py +197 -0
  79. tests/models/jax/test_llama_guard_4.py +242 -0
  80. tests/models/jax/test_qwen2.py +172 -0
  81. tests/models/jax/test_qwen2_5_vl.py +605 -0
  82. tests/models/jax/test_qwen3.py +169 -0
  83. tests/models/jax/test_weight_loading.py +180 -0
  84. tests/models/jax/utils/__init__.py +13 -0
  85. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  86. tests/platforms/__init__.py +13 -0
  87. tests/platforms/test_tpu_platform.py +54 -0
  88. tests/runner/__init__.py +13 -0
  89. tests/runner/test_block_table.py +395 -0
  90. tests/runner/test_input_batch.py +226 -0
  91. tests/runner/test_kv_cache.py +220 -0
  92. tests/runner/test_kv_cache_manager.py +498 -0
  93. tests/runner/test_multimodal_manager.py +429 -0
  94. tests/runner/test_persistent_batch_manager.py +84 -0
  95. tests/runner/test_speculative_decoding_manager.py +368 -0
  96. tests/runner/test_structured_decoding_manager.py +220 -0
  97. tests/runner/test_tpu_runner.py +261 -0
  98. tests/runner/test_tpu_runner_dp.py +1099 -0
  99. tests/runner/test_tpu_runner_mesh.py +200 -0
  100. tests/runner/test_utils.py +411 -0
  101. tests/spec_decode/__init__.py +13 -0
  102. tests/spec_decode/test_eagle3.py +311 -0
  103. tests/test_base.py +14 -0
  104. tests/test_tpu_info.py +14 -0
  105. tests/test_utils.py +1 -43
  106. tests/worker/__init__.py +13 -0
  107. tests/worker/tpu_worker_test.py +414 -0
  108. tpu_inference/__init__.py +14 -0
  109. tpu_inference/core/__init__.py +13 -0
  110. tpu_inference/core/sched/__init__.py +13 -0
  111. tpu_inference/core/sched/dp_scheduler.py +372 -56
  112. tpu_inference/distributed/__init__.py +13 -0
  113. tpu_inference/distributed/jax_parallel_state.py +14 -0
  114. tpu_inference/distributed/tpu_connector.py +14 -9
  115. tpu_inference/distributed/utils.py +56 -4
  116. tpu_inference/executors/__init__.py +13 -0
  117. tpu_inference/executors/ray_distributed_executor.py +20 -3
  118. tpu_inference/experimental/__init__.py +13 -0
  119. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  120. tpu_inference/kernels/__init__.py +13 -0
  121. tpu_inference/kernels/collectives/__init__.py +13 -0
  122. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  123. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  124. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  125. tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
  126. tpu_inference/kernels/megablox/__init__.py +13 -0
  127. tpu_inference/kernels/megablox/common.py +54 -0
  128. tpu_inference/kernels/megablox/gmm.py +646 -0
  129. tpu_inference/kernels/mla/__init__.py +13 -0
  130. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  131. tpu_inference/kernels/mla/v1/kernel.py +20 -26
  132. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  133. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  134. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  135. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  136. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
  137. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
  138. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  139. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
  140. tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
  141. tpu_inference/layers/__init__.py +13 -0
  142. tpu_inference/layers/common/__init__.py +13 -0
  143. tpu_inference/layers/common/attention_interface.py +26 -19
  144. tpu_inference/layers/common/attention_metadata.py +14 -0
  145. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  146. tpu_inference/layers/common/quant_methods.py +15 -0
  147. tpu_inference/layers/common/quantization.py +282 -0
  148. tpu_inference/layers/common/sharding.py +22 -3
  149. tpu_inference/layers/common/utils.py +94 -0
  150. tpu_inference/layers/jax/__init__.py +13 -0
  151. tpu_inference/layers/jax/attention/__init__.py +13 -0
  152. tpu_inference/layers/jax/attention/attention.py +19 -6
  153. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
  154. tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
  155. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  156. tpu_inference/layers/jax/base.py +14 -0
  157. tpu_inference/layers/jax/constants.py +13 -0
  158. tpu_inference/layers/jax/layers.py +14 -0
  159. tpu_inference/layers/jax/misc.py +14 -0
  160. tpu_inference/layers/jax/moe/__init__.py +13 -0
  161. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  162. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  163. tpu_inference/layers/jax/moe/moe.py +43 -3
  164. tpu_inference/layers/jax/pp_utils.py +53 -0
  165. tpu_inference/layers/jax/rope.py +14 -0
  166. tpu_inference/layers/jax/rope_interface.py +14 -0
  167. tpu_inference/layers/jax/sample/__init__.py +13 -0
  168. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  169. tpu_inference/layers/jax/sample/sampling.py +15 -1
  170. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  171. tpu_inference/layers/jax/transformer_block.py +14 -0
  172. tpu_inference/layers/vllm/__init__.py +13 -0
  173. tpu_inference/layers/vllm/attention.py +4 -4
  174. tpu_inference/layers/vllm/fused_moe.py +100 -455
  175. tpu_inference/layers/vllm/linear.py +64 -0
  176. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  177. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  178. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  179. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  180. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  181. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  182. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  183. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
  184. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  188. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
  189. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  190. tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
  191. tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
  192. tpu_inference/lora/__init__.py +13 -0
  193. tpu_inference/lora/torch_lora_ops.py +8 -13
  194. tpu_inference/models/__init__.py +13 -0
  195. tpu_inference/models/common/__init__.py +13 -0
  196. tpu_inference/models/common/model_loader.py +37 -16
  197. tpu_inference/models/jax/__init__.py +13 -0
  198. tpu_inference/models/jax/deepseek_v3.py +113 -124
  199. tpu_inference/models/jax/gpt_oss.py +23 -7
  200. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  201. tpu_inference/models/jax/llama3.py +99 -36
  202. tpu_inference/models/jax/llama4.py +14 -0
  203. tpu_inference/models/jax/llama_eagle3.py +14 -0
  204. tpu_inference/models/jax/llama_guard_4.py +15 -1
  205. tpu_inference/models/jax/qwen2.py +17 -2
  206. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  207. tpu_inference/models/jax/qwen3.py +17 -2
  208. tpu_inference/models/jax/utils/__init__.py +13 -0
  209. tpu_inference/models/jax/utils/file_utils.py +14 -0
  210. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  211. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
  213. tpu_inference/models/jax/utils/weight_utils.py +32 -1
  214. tpu_inference/models/vllm/__init__.py +13 -0
  215. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
  216. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  217. tpu_inference/platforms/__init__.py +14 -0
  218. tpu_inference/platforms/tpu_platform.py +27 -29
  219. tpu_inference/runner/__init__.py +13 -0
  220. tpu_inference/runner/compilation_manager.py +69 -35
  221. tpu_inference/runner/kv_cache.py +14 -0
  222. tpu_inference/runner/kv_cache_manager.py +15 -2
  223. tpu_inference/runner/lora_utils.py +16 -1
  224. tpu_inference/runner/multimodal_manager.py +16 -2
  225. tpu_inference/runner/persistent_batch_manager.py +14 -0
  226. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  227. tpu_inference/runner/structured_decoding_manager.py +14 -0
  228. tpu_inference/runner/tpu_runner.py +30 -10
  229. tpu_inference/spec_decode/__init__.py +13 -0
  230. tpu_inference/spec_decode/jax/__init__.py +13 -0
  231. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  232. tpu_inference/tpu_info.py +14 -0
  233. tpu_inference/utils.py +31 -30
  234. tpu_inference/worker/__init__.py +13 -0
  235. tpu_inference/worker/tpu_worker.py +23 -7
  236. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
  237. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  238. tpu_inference/layers/vllm/linear_common.py +0 -208
  239. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  240. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  241. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  242. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  245. tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
  246. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  247. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  248. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import functools
 from dataclasses import dataclass
 from typing import Optional
@@ -1,3 +1,17 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from dataclasses import dataclass
 from typing import Any, Optional, Tuple
 
@@ -0,0 +1,13 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -15,6 +15,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
 from tpu_inference import utils
 from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.logger import init_logger
 from tpu_inference.models.vllm.vllm_model_wrapper_context import \
     get_vllm_model_wrapper_context
@@ -117,10 +118,9 @@ class PallasAttentionBackendImpl(AttentionImpl):
         query, key, value = jax_view(query), jax_view(key), jax_view(value)
         q_scale = k_scale = v_scale = None
         if self.kv_cache_quantized_dtype:
-            key, value = utils.quantize_kv(key, value,
-                                           self.kv_cache_quantized_dtype,
-                                           layer._k_scale_float,
-                                           layer._v_scale_float)
+            key, value = quantize_kv(self.kv_cache_quantized_dtype, key, value,
+                                     layer._k_scale_float,
+                                     layer._v_scale_float)
         # TODO(kyuyeunk): Enable w8a8 when VREG spill issue is resolved.
         # q_scale = layer._q_scale_float
         k_scale = layer._k_scale_float
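The hunk above moves KV-cache quantization out of tpu_inference.utils and into the new tpu_inference.layers.common.quantization module, with the quantized dtype now passed as the first positional argument. A minimal sketch of what such a helper does, assuming symmetric per-tensor scales and an int8 cache dtype (quantize_kv_sketch is a hypothetical stand-in written for illustration, not the packaged implementation; the real kv_cache_quantized_dtype may also be an fp8 type):

import jax.numpy as jnp

def quantize_kv_sketch(kv_dtype, key, value, k_scale, v_scale):
    # Hypothetical stand-in mirroring the new argument order in the diff
    # (dtype first, then key/value, then per-tensor scales).
    # Symmetric quantization: divide by the scale, round, clip to the
    # representable range, then cast to the KV-cache dtype.
    info = jnp.iinfo(kv_dtype)
    q_key = jnp.clip(jnp.round(key / k_scale), info.min, info.max).astype(kv_dtype)
    q_value = jnp.clip(jnp.round(value / v_scale), info.min, info.max).astype(kv_dtype)
    return q_key, q_value

key = jnp.ones((4, 8, 128), dtype=jnp.bfloat16)
value = jnp.ones((4, 8, 128), dtype=jnp.bfloat16)
q_key, q_value = quantize_kv_sketch(jnp.int8, key, value, k_scale=0.05, v_scale=0.05)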
@@ -1,469 +1,114 @@
-import functools
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
 
 import jax
-from jax import numpy as jnp
-from jax import shard_map
-from jax.experimental.pallas.ops.tpu.megablox.gmm import gmm
+import jax.numpy as jnp
+import torch
 from jax.sharding import Mesh
-from jax.sharding import PartitionSpec as P
+from torchax.interop import jax_view, torch_view
+from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 
-from tpu_inference.layers.vllm.linear_common import \
-    slice_sharded_tensor_for_concatenation
+from tpu_inference import envs
+from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
+from tpu_inference.layers.common.fused_moe_gmm import fused_moe_func
+from tpu_inference.logger import init_logger
 
+logger = init_logger(__name__)
 
-def activation_fn(activation: str, x1: jax.Array, x2: jax.Array) -> jax.Array:
-    match activation:
-        case "silu":
-            return jax.nn.silu(x1) * x2
-        case "swigluoai":
-            return _swigluoai(x1, x2)
-        case _:
-            raise NotImplementedError(
-                f"FusedMoE does not support {activation} activation")
 
+class FusedMoEBackend(Enum):
+    FUSED_MOE = "fused_moe"
+    GMM_EP = "gmm_ep"
+    GMM_TP = "gmm_tp"
 
-def _swigluoai(x1: jax.Array,
-               x2: jax.Array,
-               alpha=1.702,
-               limit=7.0) -> jax.Array:
-    x1 = jnp.clip(x1, a_max=limit)
-    x2 = jnp.clip(x2, a_min=-limit, a_max=limit)
 
-    gated_activation = x1 * jax.nn.sigmoid(alpha * x1)
+def select_moe_backend(moe: FusedMoEConfig):
+    if envs.USE_MOE_EP_KERNEL:
+        if moe.use_ep:
+            return FusedMoEBackend.FUSED_MOE
+        logger.warning_once(
+            "USE_MOE_EP_KERNEL=1 but expert parallelism is not "
+            "enabled. Falling back to gmm implementation.")
 
-    return gated_activation * (x2 + 1)
+    if moe.use_ep:
+        return FusedMoEBackend.GMM_EP
 
+    # Use default implementation.
+    return FusedMoEBackend.GMM_TP
 
-def _round_up_to_multiple_of_128_within_limit(x: int, limit: int) -> int:
-    """
-    Rounds the given integer `x` up to the nearest multiple of 128, without
-    exceeding the specified `limit`.
 
-    If `x` is less than or equal to 128, returns 128.
-    If `x` is less than `limit`, returns the smallest multiple of 128 greater
-    than or equal to `x`.
-    If `x` is greater than or equal to `limit`, searches for the largest
-    multiple of 128 less than or equal to `limit` (down to 512) that divides `x`
-    evenly, and returns it.
-    If no such candidate is found, returns `limit`.
-
-    Args:
-        x (int): The integer to round up.
-        limit (int): The upper bound (must be a multiple of 128).
-
-    Returns:
-        int: The rounded value according to the rules above.
-
-    Raises:
-        AssertionError: If `limit` is less than 128 or not a multiple of 128.
-    """
-    assert limit >= 128 and limit % 128 == 0
-    if x <= 128:
-        return 128
-    if x < limit:
-        return (x + 127) // 128 * 128
-    for candidate in range(limit, 511, -128):
-        if x % candidate == 0:
-            return candidate
-    return limit
-
-
-def _get_tiling_size_for_gmm_kernel(m: int, k: int, n: int,
-                                    g: int) -> tuple[int, int, int]:
-    """
-    Calculate optimal tiling sizes for a GMM kernel in a Mixture of Experts
-    (MoE) setting.
-
-    Args:
-        m (int): The total number of tokens.
-        n (int): The output feature dimension.
-        k (int): The input feature dimension.
-        g (int): The number of experts.
-
-    Returns:
-        tuple[int, int, int]: A tuple (tm, tk, tn)
-    """
-
-    # TODO(Chengji): increase the upper limit tiling size of m when we can set
-    # the vmem size to be used for gmm kernel.
-    # NOTE: In average each expert has m // g tokens, but as it might be
-    # unbalanced, here we doubled the token size when choosing tiling size of m.
-    # 2m//g can be either greater or less than 512. If there are 32 tokens and
-    # topk=2, m=topk * num_tokens=64, in this case, 2*m//g will be less than
-    # 512.
-    tm = _round_up_to_multiple_of_128_within_limit(2 * m // g, 512)
-    tm = min(tm, m)  # there's a requirement that m % tm == 0
-    # k/n correspond to n_input_features/n_output_features in the matmul so they
-    # are normally greater than 2048, unless the num shards is large.
-    tk = _round_up_to_multiple_of_128_within_limit(k, 2048)
-    tn = _round_up_to_multiple_of_128_within_limit(n, 2048)
-    return tm, tk, tn
-
-
-def tensor_sharded_gmm_merged_column_parallel(
-    lhs: jax.Array,
-    rhs: jax.Array,
-    rhs_bias: jax.Array | None,
-    group_sizes: jax.Array,
-    mesh: Mesh,
-) -> tuple[jax.Array, jax.Array]:
-
-    def _gmm(lhs, rhs, group_sizes):
-        m, g, n, k = lhs.shape[0], *rhs.shape
-        tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
-        return gmm(
-            lhs,
-            rhs,
-            group_sizes,
-            preferred_element_type=lhs.dtype,
-            tiling=(tm, tk, tn),
-            transpose_rhs=True,
-            group_offset=jnp.array(0),
-        )
-
-    gmm_result = shard_map(
-        _gmm,
-        mesh=mesh,
-        in_specs=(P("data", None), P(None, "model", None), P("data")),
-        out_specs=(P("data", "model")),
-        check_vma=False,
-    )(lhs, rhs, group_sizes)
-
-    if rhs_bias is not None:
-
-        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
-            rhs_bias = jnp.repeat(
-                rhs_bias_local,
-                group_sizes_global,
-                0,
-                total_repeat_length=gmm_result_local.shape[0])
-            return gmm_result_local + rhs_bias
-
-        gmm_result = shard_map(
-            _add_bias,
-            mesh=mesh,
-            in_specs=(P("data", "model"), P(None, "model"), P("data")),
-            out_specs=(P("data", "model")),
-        )(gmm_result, rhs_bias, group_sizes)
-        gmm_result = gmm_result.astype(lhs.dtype)
-
-    tp_size = mesh.shape["model"]
-    intermediate_size = gmm_result.shape[-1] // 2
-    output_sizes = [intermediate_size, intermediate_size]
-    return slice_sharded_tensor_for_concatenation(gmm_result, output_sizes,
-                                                  tp_size)
-
-
-def tensor_sharded_gmm_row_parallel(
-    lhs: jax.Array,
-    rhs: jax.Array,
-    rhs_bias: jax.Array | None,
-    group_sizes: jax.Array,
-    mesh: Mesh,
-) -> jax.Array:
-
-    def _gmm_all_reduce(lhs, rhs, group_sizes):
-        m, g, n, k = lhs.shape[0], *rhs.shape
-        tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
-        out = gmm(
-            lhs,
-            rhs,
-            group_sizes,
-            preferred_element_type=lhs.dtype,
-            tiling=(tm, tk, tn),
-            transpose_rhs=True,
-            group_offset=jnp.array(0),
-        )
-        return jax.lax.psum(out, axis_name="model")
-
-    gmm_result = shard_map(
-        _gmm_all_reduce,
-        mesh=mesh,
-        in_specs=(P("data", "model"), P(None, None, "model"), P("data")),
-        out_specs=(P("data")),
-        check_vma=False,
-    )(lhs, rhs, group_sizes)
-
-    if rhs_bias is not None:
-
-        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
-            rhs_bias = jnp.repeat(
-                rhs_bias_local,
-                group_sizes_global,
-                0,
-                total_repeat_length=gmm_result_local.shape[0])
-            return gmm_result_local + rhs_bias
-
-        gmm_result = shard_map(
-            _add_bias,
-            mesh=mesh,
-            in_specs=(P("data"), P(), P("data")),
-            out_specs=(P("data")),
-        )(gmm_result, rhs_bias, group_sizes)
-
-    return gmm_result.astype(lhs.dtype)
-
-
-def expert_sharded_gmm(
-    lhs: jax.Array,
-    rhs: jax.Array,
-    group_sizes: jax.Array,
-    mesh: Mesh,
-) -> jax.Array:
-    ep_size = mesh.shape["model"]
-
-    num_experts = rhs.shape[0]
-    num_experts_per_shard = num_experts // ep_size
-    group_offset = jnp.arange(0, num_experts, num_experts_per_shard)
-
-    def _gmm(lhs, rhs, group_sizes, group_offset):
-        m, g, n, k = lhs.shape[0], *rhs.shape
-        tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
-
-        gmm_res = gmm(
-            lhs=lhs,
-            rhs=rhs,
-            group_sizes=group_sizes,
-            preferred_element_type=lhs.dtype,
-            tiling=(tm, tk, tn),
-            transpose_rhs=True,
-            group_offset=group_offset[0],
-        )
-        return gmm_res
-
-    # The result from gmm on each shard has the same shape, but only the rows
-    # for this shard has non-zero values. Taking below as an working example:
-    # A, A, A, A 0, 0, 0, 0 0, 0, 0, 0 0, 0, 0, 0
-    # A, A, A, A 0, 0, 0, 0 0, 0, 0, 0 0, 0, 0, 0
-    # A, A, A, A 0, 0, 0, 0 0, 0, 0, 0 0, 0, 0, 0
-    # 0, 0, 0, 0 B, B, B, B 0, 0, 0, 0 0, 0, 0, 0
-    # 0, 0, 0, 0 B, B, B, B 0, 0, 0, 0 0, 0, 0, 0
-    # 0, 0, 0, 0 0, 0, 0, 0 C, C, C, C 0, 0, 0, 0
-    # 0, 0, 0, 0 0, 0, 0, 0 C, C, C, C 0, 0, 0, 0
-    # 0, 0, 0, 0 0, 0, 0, 0 C, C, C, C 0, 0, 0, 0
-    # 0, 0, 0, 0 0, 0, 0, 0 C, C, C, C 0, 0, 0, 0
-    # 0, 0, 0, 0 0, 0, 0, 0 C, C, C, C 0, 0, 0, 0
-    # 0, 0, 0, 0 0, 0, 0, 0 0, 0, 0, 0 D, D, D, D
-    # 0, 0, 0, 0 0, 0, 0, 0 0, 0, 0, 0 D, D, D, D
-    # 0, 0, 0, 0 0, 0, 0, 0 0, 0, 0, 0 D, D, D, D
-    # 0, 0, 0, 0 0, 0, 0, 0 0, 0, 0, 0 D, D, D, D
-    # shard-0 shard-1 shard-2 shard-3
-    # Each shards has 3 (row A), 2 (row B), 5 (row C) and 4 (row D).
-    gmm_res = shard_map(
-        _gmm,
-        mesh=mesh,
-        in_specs=(P(), P("model", None, None), P(), P("model")),
-        out_specs=(P("model", None)),
-        check_vma=False,
-    )(lhs, rhs, group_sizes, group_offset)
-
-    # For i-th shard, it is responsible groups (AKA experts) from
-    # i*num_experts_per_shard to (i+1)*num_experts_per_shard We sum them up to
-    # get total rows in that shard, and that is the size for shard to send to
-    # its peers. This is also the number of non-zero rows from the gmm results.
-    # In the working example, send_sizes would be [3, 2, 5, 4].
-
-    # group_sizes has shape of [num_tokens_per_shard * num_experts_per_shard].
-    # So reshaping to [num_tokens_per_shard, num_experts_per_shard] and applying
-    # sum(axis=1) will get desired send_sizes shaped [num_tokens_per_shard].
-    send_sizes = group_sizes.reshape(-1, num_experts_per_shard).sum(axis=1)
-    # In the working example, input_offsets would be [0, 3, 5, 10]
-    input_offsets = jnp.concatenate((jnp.array([0]), send_sizes.cumsum()[:-1]))
-    output_offsets = input_offsets
-    recv_sizes = send_sizes
-
-    def _ragged_all_to_all(operand, input_offsets, send_sizes, output_offsets,
-                           recv_sizes):
-        output = jnp.zeros_like(operand)
-
-        # input_offsets, send_sizes and output_offsets are sharded and there is
-        # only 1 elemnt in each shard, we are taking the 0-th element from them
-        # just so that jnp.repeat generates the arrays with correct shape.
-        input_offsets_of_shard = jnp.repeat(input_offsets[0], ep_size)
-        send_sizes_of_shard = jnp.repeat(send_sizes[0], ep_size)
-        output_offsets_of_shard = jnp.repeat(output_offsets[0], ep_size)
-
-        # recv_sizes is replicated across shards, because all the shards receive
-        # the same data and write to the output in the same way (same
-        # output_offsets and same recv_sizes) and thus generates replicated
-        # output.
-        recv_sizes_of_shard = recv_sizes
-
-        # In the working example, for each shard, the values of the offsets and
-        # sizes would be:
-        # shard-0 shard-1 shard-2 shard-3
-        # input_offsets_of_shard [0, 0, 0, 0] [3, 3, 3, 3] [5, 5, 5, 5] [10,10,10,10]
-        # send_sizes_of_shard [3, 3, 3, 3] [2, 2, 2, 2] [5, 5, 5, 5] [4, 4, 4, 4 ]
-        # output_offsets_of_shard [0, 0, 0, 0] [0, 0, 0, 0] [0, 0, 0, 0] [10,10,10,10]
-        # recv_sizes_of_shard [3, 2, 5, 4] [3, 2, 5, 4] [3, 2, 5, 4] [3, 2, 5, 4]
-        return jax.lax.ragged_all_to_all(operand,
-                                         output,
-                                         input_offsets_of_shard,
-                                         send_sizes_of_shard,
-                                         output_offsets_of_shard,
-                                         recv_sizes_of_shard,
-                                         axis_name="model")
-
-    # Use ragged_all_to_all to send the result from gmm for each expert to all
-    # the shards. In the working example, the result would be:
-    # A, A, A, A A, A, A, A A, A, A, A A, A, A, A
-    # A, A, A, A A, A, A, A A, A, A, A A, A, A, A
-    # A, A, A, A A, A, A, A A, A, A, A A, A, A, A
-    # B, B, B, B B, B, B, B B, B, B, B B, B, B, B
-    # B, B, B, B B, B, B, B B, B, B, B B, B, B, B
-    # C, C, C, C C, C, C, C C, C, C, C C, C, C, C
-    # C, C, C, C C, C, C, C C, C, C, C C, C, C, C
-    # C, C, C, C C, C, C, C C, C, C, C C, C, C, C
-    # C, C, C, C C, C, C, C C, C, C, C C, C, C, C
-    # C, C, C, C C, C, C, C C, C, C, C C, C, C, C
-    # D, D, D, D D, D, D, D D, D, D, D D, D, D, D
-    # D, D, D, D D, D, D, D D, D, D, D D, D, D, D
-    # D, D, D, D D, D, D, D D, D, D, D D, D, D, D
-    # D, D, D, D D, D, D, D D, D, D, D D, D, D, D
-    # shard-0 shard-1 shard-2 shard-3
-    return shard_map(
-        _ragged_all_to_all,
-        mesh=mesh,
-        in_specs=(P("model", None), P("model"), P("model"), P("model"), P()),
-        out_specs=(P()),
-        check_vma=False,
-    )(gmm_res, input_offsets, send_sizes, output_offsets, recv_sizes)
-
-
-@functools.partial(
-    jax.jit,
-    static_argnames=(
-        "topk",
-        "renormalize",
-        "mesh",
-        "use_ep",
-        "activation",
-    ),
-)
-def fused_moe_func(
-    hidden_states: jax.Array,
-    w1: jax.Array,
-    w2: jax.Array,
-    w1_bias: jax.Array | None,
-    w2_bias: jax.Array | None,
-    gating_output: jax.Array,
-    topk: int,
-    renormalize: bool,
+def fused_moe_apply(
+    layer: torch.nn.Module,
+    x: torch.Tensor,
+    router_logits: torch.Tensor,
+    moe_backend: FusedMoEBackend,
     mesh: Mesh,
-    use_ep: bool,
-    activation: str,
-) -> jax.Array:
-    """
-    Route tokens in hidden_states into each experts based on routing
-    information in gating_out and performs moe with w1 and w2 weights.
-
-    Args:
-        hidden_states: [num_tokens, hidden_size]
-        w1: first moe weights [num_experts, intermediate_size * 2, hidden_size]
-        w2: second moe weights [num_experts, hidden_size, intermediate_size]
-        w1_bias: optional bias of w1 [num_experts, intermediate_size * 2]
-        w2_bias: optional bias of w2 [num_experts, hidden_size]
-        gating_output: routing information of tokens [num_tokens, num_experts]
-        topk: number of experts to choose per token.
-        renormalize: normalize gating_output.
-        mesh: mesh to perform moe.
-        use_ep: use expert parallelism.
-        activation: activation function to perform on the output of w1.
-
-    Returns:
-        Output of moe operation [num_tokens, hidden_size]
-    """
-    if use_ep and (w1_bias is not None or w2_bias is not None):
-        raise NotImplementedError(
-            "Bias is not supported when using expert parallelism.")
-
-    num_tokens = hidden_states.shape[0]
-    global_num_experts, hidden_size, intermediate_size = w2.shape
-    dtype = hidden_states.dtype
-
-    assert (num_tokens * topk) % 16 == 0, (
-        "The kernel requires num_tokens * topk to be a multiple of "
-        f"16 but got {num_tokens}*{topk}={num_tokens*topk}")
-    assert hidden_states.shape == (num_tokens, hidden_size)
-    assert gating_output.shape == (num_tokens, global_num_experts)
-    assert w1.shape == (global_num_experts, intermediate_size * 2, hidden_size)
-
-    topk_weights = jax.nn.softmax(gating_output.astype(jnp.float32), axis=-1)
-    topk_weights, topk_indices = jax.lax.top_k(topk_weights, k=topk)
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)
-    topk_weights = topk_weights.astype(dtype)
-
-    def _process_tokens_locally(hidden_states_local, topk_indices_local):
-        num_tokens_local = hidden_states_local.shape[0]
-        topk_indices_flat = topk_indices_local.flatten()
-        topk_argsort_indices = jnp.argsort(topk_indices_flat)
-        topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
-        token_indices = jnp.arange(num_tokens_local,
-                                   dtype=jnp.int32).repeat(topk)
-        token_indices_sorted = token_indices[topk_argsort_indices]
-        group_sizes_local = jnp.bincount(topk_indices_flat,
-                                         length=global_num_experts)
-
-        x = hidden_states_local[token_indices_sorted]
-        return x, group_sizes_local, topk_argsort_revert_indices
-
-    x, group_sizes, topk_argsort_revert_indices = shard_map(
-        _process_tokens_locally,
-        mesh=mesh,
-        in_specs=(P("data", None), P("data", None)),
-        out_specs=(P("data", None), P("data"), P("data")),
-    )(hidden_states, topk_indices)
-
-    if use_ep:
-        x = expert_sharded_gmm(
-            x,
-            w1,
-            group_sizes,
-            mesh=mesh,
-        )
-        x1, x2 = jnp.split(x, 2, -1)
-
-        x = activation_fn(activation, x1, x2)
-
-        x = expert_sharded_gmm(
-            x,
-            w2,
-            group_sizes,
-            mesh=mesh,
-        )
-    else:
-        x1, x2 = tensor_sharded_gmm_merged_column_parallel(
-            x,
-            w1,
-            w1_bias,
-            group_sizes,
-            mesh=mesh,
-        )
-
-        x = activation_fn(activation, x1, x2)
-
-        x = tensor_sharded_gmm_row_parallel(
-            x,
-            w2,
-            w2_bias,
-            group_sizes,
-            mesh=mesh,
-        )
-
-    def _finalize_output(x_local, topk_argsort_revert_indices_local,
-                         topk_weights_local):
-        x_local = x_local[topk_argsort_revert_indices_local].reshape(
-            -1, topk, hidden_size)
-        x_local = x_local * jnp.expand_dims(topk_weights_local, axis=-1)
-        x_local = x_local.sum(axis=-2)
-        return x_local
-
-    x = shard_map(
-        _finalize_output,
-        mesh=mesh,
-        in_specs=(P("data", None), P("data"), P("data", None)),
-        out_specs=(P("data", None)),
-    )(x, topk_argsort_revert_indices, topk_weights)
-
-    return x[:num_tokens, :hidden_size]
+    extra_backend_kwargs: dict,
+) -> torch.Tensor:
+    assert isinstance(layer, FusedMoE)
+    if layer.scoring_func != "softmax":
+        raise NotImplementedError("Only softmax is supported for scoring_func")
+
+    x = jax_view(x)
+    gating_output = jax_view(router_logits)
+
+    w13_weight = jax_view(layer.w13_weight)
+    w13_weight_scale = jax_view(getattr(layer, "w13_weight_scale", None))
+    w13_bias = jax_view(getattr(layer, "w13_bias", None))
+    w2_weight = jax_view(layer.w2_weight)
+    w2_weight_scale = jax_view(getattr(layer, "w2_weight_scale", None))
+    w2_bias = jax_view(getattr(layer, "w2_bias", None))
+
+    with jax.named_scope(layer._get_name()):
+        match moe_backend:
+            case FusedMoEBackend.FUSED_MOE:
+                actual_hidden_size = x.shape[-1]
+                padding_size = w13_weight.shape[-2] - actual_hidden_size
+                x = jnp.pad(x, ((0, 0), (0, padding_size)))
+                output = fused_ep_moe(
+                    mesh=mesh,
+                    tokens=x,
+                    w1=w13_weight,
+                    w2=w2_weight,
+                    w1_scale=w13_weight_scale,
+                    w2_scale=w2_weight_scale,
+                    b1=w13_bias,
+                    b2=w2_bias,
+                    gating_output=gating_output,
+                    top_k=layer.top_k,
+                    renormalize_topk_logits=layer.renormalize,
+                    act_fn=layer.activation,
+                    **extra_backend_kwargs,
+                )[:, :actual_hidden_size]
+            case FusedMoEBackend.GMM_EP | FusedMoEBackend.GMM_TP:
+                output = fused_moe_func(
+                    hidden_states=x,
+                    w1=w13_weight,
+                    w2=w2_weight,
+                    w1_scale=w13_weight_scale,
+                    w2_scale=w2_weight_scale,
+                    w1_bias=w13_bias,
+                    w2_bias=w2_bias,
+                    gating_output=gating_output,
+                    topk=layer.top_k,
+                    renormalize=layer.renormalize,
+                    mesh=mesh,
+                    use_ep=layer.use_ep,
+                    activation=layer.activation,
+                )
+
+    return torch_view(output)
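Taken together, the rewritten tpu_inference/layers/vllm/fused_moe.py becomes a thin dispatcher: select_moe_backend picks one of three backends from the MoE config and the USE_MOE_EP_KERNEL environment flag, and fused_moe_apply bridges torch tensors to JAX via jax_view/torch_view before calling either the fused expert-parallel kernel or the GMM path that now lives in tpu_inference.layers.common.fused_moe_gmm. A rough sketch of the selection behaviour, assuming a config with a boolean use_ep field (the standalone function and enum below are illustrative stand-ins that mirror the diff, not the packaged code):

from enum import Enum

class FusedMoEBackend(Enum):
    FUSED_MOE = "fused_moe"   # fused expert-parallel Pallas kernel
    GMM_EP = "gmm_ep"         # grouped-matmul path with expert parallelism
    GMM_TP = "gmm_tp"         # grouped-matmul path with tensor parallelism

def select_backend_sketch(use_ep: bool, use_moe_ep_kernel: bool) -> FusedMoEBackend:
    # Mirrors the decision order shown in select_moe_backend above: the env
    # flag only takes effect when expert parallelism is enabled; otherwise the
    # code warns and falls through to the GMM implementations.
    if use_moe_ep_kernel and use_ep:
        return FusedMoEBackend.FUSED_MOE
    if use_ep:
        return FusedMoEBackend.GMM_EP
    return FusedMoEBackend.GMM_TP

assert select_backend_sketch(True, True) is FusedMoEBackend.FUSED_MOE
assert select_backend_sketch(True, False) is FusedMoEBackend.GMM_EP
assert select_backend_sketch(False, True) is FusedMoEBackend.GMM_TP  # falls back, with a warning in the real code
assert select_backend_sketch(False, False) is FusedMoEBackend.GMM_TP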