tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.2rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +89 -26
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +46 -17
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +44 -17
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/METADATA +7 -9
- tpu_inference-0.13.2rc3.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/top_level.txt +0 -0
@@ -1,18 +1,32 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import functools
 
 import jax
 from jax import numpy as jnp
-from jax.
-from jax.
-from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
 
+from tpu_inference.kernels.megablox.gmm import gmm
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.vllm.linear_common import \
     slice_sharded_tensor_for_concatenation
-
-P = PartitionSpec
+from tpu_inference.utils import get_mesh_shape_product
 
 
-def activation_fn(activation: str, x1, x2):
+def activation_fn(activation: str, x1: jax.Array, x2: jax.Array) -> jax.Array:
     match activation:
         case "silu":
             return jax.nn.silu(x1) * x2
@@ -23,7 +37,10 @@ def activation_fn(activation: str, x1, x2):
                 f"FusedMoE does not support {activation} activation")
 
 
-def _swigluoai(x1
+def _swigluoai(x1: jax.Array,
+               x2: jax.Array,
+               alpha=1.702,
+               limit=7.0) -> jax.Array:
     x1 = jnp.clip(x1, a_max=limit)
     x2 = jnp.clip(x2, a_min=-limit, a_max=limit)
 
@@ -101,142 +118,124 @@ def _get_tiling_size_for_gmm_kernel(m: int, k: int, n: int,
 def tensor_sharded_gmm_merged_column_parallel(
     lhs: jax.Array,
     rhs: jax.Array,
+    rhs_scale: jax.Array | None,
     rhs_bias: jax.Array | None,
     group_sizes: jax.Array,
-    transpose_rhs: bool,
     mesh: Mesh,
- [… removed lines not shown in this view]
+) -> list[jax.Array]:
+
+    def _gmm(lhs, rhs, rhs_scale, rhs_bias, group_sizes):
+        m, g, n, k = lhs.shape[0], *rhs.shape
+        tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+        return gmm(
+            lhs,
+            rhs,
+            group_sizes,
+            rhs_scale=rhs_scale,
+            rhs_bias=rhs_bias,
+            preferred_element_type=lhs.dtype,
+            tiling=(tm, tk, tn),
+            transpose_rhs=True,
+            group_offset=jnp.array(0),
+        )
+
+    rhs_scale_spec = None if rhs_scale is None else P(
+        None, None, None, ShardingAxisName.MLP_TENSOR)
+    rhs_bias_spec = None if rhs_bias is None else P(
+        None, None, ShardingAxisName.MLP_TENSOR)
+
+    gmm_result = jax.shard_map(
         _gmm,
         mesh=mesh,
-        in_specs=(P(
- [… removed lines not shown in this view]
-                             0,
-                             total_repeat_length=m // mesh.shape["data"])
-        return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)
-
-    gmm_result = shard_map(
-        _add_bias,
-        mesh=mesh,
-        in_specs=(P("data", "model"), P(None, "model"), P("data")),
-        out_specs=(P("data", "model")),
-    )(gmm_result, rhs_bias, group_sizes)
-
-    n_shards = mesh.shape["model"]
+        in_specs=(P(ShardingAxisName.MLP_DATA,
+                    None), P(None, ShardingAxisName.MLP_TENSOR,
+                             None), rhs_scale_spec, rhs_bias_spec,
+                  P(ShardingAxisName.MLP_DATA)),
+        out_specs=(P(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR)),
+        check_vma=False,
+    )(lhs, rhs, rhs_scale, rhs_bias, group_sizes)
+
+    tp_size = get_mesh_shape_product(mesh, ShardingAxisName.MLP_TENSOR)
+    intermediate_size = gmm_result.shape[-1] // 2
     output_sizes = [intermediate_size, intermediate_size]
-
     return slice_sharded_tensor_for_concatenation(gmm_result, output_sizes,
-
+                                                   tp_size)
 
 
 def tensor_sharded_gmm_row_parallel(
     lhs: jax.Array,
     rhs: jax.Array,
+    rhs_scale: jax.Array | None,
     rhs_bias: jax.Array | None,
     group_sizes: jax.Array,
-    transpose_rhs: bool,
     mesh: Mesh,
 ) -> jax.Array:
- [… removed lines not shown in this view]
-        gmm
- [… removed lines not shown in this view]
+
+    def _gmm_all_reduce(lhs, rhs, rhs_scale, rhs_bias, group_sizes):
+        m, g, n, k = lhs.shape[0], *rhs.shape
+        tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+        if rhs_bias is not None:
+            shard_id = jax.lax.axis_index(ShardingAxisName.MLP_TENSOR).sum()
+            rhs_bias = jnp.where(shard_id == 0, rhs_bias, 0)
+        out = gmm(
+            lhs,
+            rhs,
+            group_sizes,
+            rhs_scale=rhs_scale,
+            rhs_bias=rhs_bias,
+            preferred_element_type=lhs.dtype,
+            tiling=(tm, tk, tn),
+            transpose_rhs=True,
+            group_offset=jnp.array(0),
+        )
+        return jax.lax.psum(out, axis_name=ShardingAxisName.MLP_TENSOR)
+
+    num_blocks = 1 if rhs_scale is None else rhs_scale.shape[1]
+    rhs_scale_spec = None if num_blocks == 1 else P(
+        None, ShardingAxisName.MLP_TENSOR, None, None)
+    rhs_bias_spec = None if rhs_bias is None else P(None, None, None)
+    gmm_result = jax.shard_map(
         _gmm_all_reduce,
         mesh=mesh,
-        in_specs=(P(
- [… removed lines not shown in this view]
-    def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
-        rhs_bis = jnp.repeat(rhs_bias_local,
-                             group_sizes_global,
-                             0,
-                             total_repeat_length=m // mesh.shape["data"])
-        return (gmm_result_local + rhs_bis).astype(gmm_result_local.dtype)
-
-    gmm_result = shard_map(
-        _add_bias,
-        mesh=mesh,
-        in_specs=(P("data"), P(), P("data")),
-        out_specs=(P("data")),
-    )(gmm_result, rhs_bias, group_sizes)
+        in_specs=(P(ShardingAxisName.MLP_DATA, ShardingAxisName.MLP_TENSOR),
+                  P(None, None, ShardingAxisName.MLP_TENSOR), rhs_scale_spec,
+                  rhs_bias_spec, P(ShardingAxisName.MLP_DATA)),
+        out_specs=(P(ShardingAxisName.MLP_DATA)),
+        check_vma=False,
+    )(lhs, rhs, rhs_scale, rhs_bias, group_sizes)
 
-    return gmm_result
+    return gmm_result.astype(lhs.dtype)
 
 
 def expert_sharded_gmm(
     lhs: jax.Array,
     rhs: jax.Array,
+    rhs_scale: jax.Array | None,
+    rhs_bias: jax.Array | None,
     group_sizes: jax.Array,
-
+    is_last_expert: bool,
     mesh: Mesh,
-    num_experts: int,
-    ep_size: int,
 ) -> jax.Array:
- [… removed lines not shown in this view]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
-
+    ep_size = get_mesh_shape_product(mesh, ShardingAxisName.MLP_TENSOR)
+    ep_p_spec = P(ShardingAxisName.EXPERT)
+    num_experts = rhs.shape[0]
     num_experts_per_shard = num_experts // ep_size
     group_offset = jnp.arange(0, num_experts, num_experts_per_shard)
- [… removed lines not shown in this view]
-    # sharded function, it has only 1 element and `group_offset.shape` is
-    # (1,) but gmm kernel requires the group_offset to be a ()-shaped array,
-    # so we group_offset[0].
-    group_offset_of_shard = group_offset[0]
+
+    def _gmm(lhs, rhs, rhs_scale, rhs_bias, group_sizes, group_offset):
+        m, g, n, k = lhs.shape[0], *rhs.shape
+        tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+
         gmm_res = gmm(
             lhs=lhs,
             rhs=rhs,
+            rhs_scale=rhs_scale,
+            rhs_bias=rhs_bias,
             group_sizes=group_sizes,
             preferred_element_type=lhs.dtype,
             tiling=(tm, tk, tn),
-            transpose_rhs=
-            group_offset=
+            transpose_rhs=True,
+            group_offset=group_offset[0],
         )
         return gmm_res
 
@@ -258,35 +257,43 @@ def expert_sharded_gmm(
     # 0, 0, 0, 0  0, 0, 0, 0  0, 0, 0, 0  D, D, D, D
     #  shard-0     shard-1     shard-2     shard-3
     # Each shards has 3 (row A), 2 (row B), 5 (row C) and 4 (row D).
-
+    lhs_spec = ep_p_spec if is_last_expert else P()
+    rhs_spec = ep_p_spec
+    rhs_scale_spec = None if rhs_scale is None else ep_p_spec
+    rhs_bias_spec = None if rhs_bias is None else ep_p_spec
+    gmm_res = jax.shard_map(
         _gmm,
         mesh=mesh,
-        in_specs=(
- [… removed lines not shown in this view]
+        in_specs=(
+            lhs_spec,
+            rhs_spec,
+            rhs_scale_spec,
+            rhs_bias_spec,
+            P(),
+            ep_p_spec,
+        ),
+        out_specs=ep_p_spec,
+        check_vma=False,
+    )(lhs, rhs, rhs_scale, rhs_bias, group_sizes, group_offset)
+
+    if not is_last_expert:
+        return gmm_res
 
     # For i-th shard, it is responsible groups (AKA experts) from
    # i*num_experts_per_shard to (i+1)*num_experts_per_shard We sum them up to
     # get total rows in that shard, and that is the size for shard to send to
     # its peers. This is also the number of non-zero rows from the gmm results.
-    # In the working example, send_sizes would be [3, 2, 5, 4]
- [… removed lines not shown in this view]
-    ]
+    # In the working example, send_sizes would be [3, 2, 5, 4].
+
+    # group_sizes has shape of [num_tokens_per_shard * num_experts_per_shard].
+    # So reshaping to [num_tokens_per_shard, num_experts_per_shard] and applying
+    # sum(axis=1) will get desired send_sizes shaped [num_tokens_per_shard].
+    send_sizes = group_sizes.reshape(-1, num_experts_per_shard).sum(axis=1)
     # In the working example, input_offsets would be [0, 3, 5, 10]
     input_offsets = jnp.concatenate((jnp.array([0]), send_sizes.cumsum()[:-1]))
     output_offsets = input_offsets
     recv_sizes = send_sizes
 
-    input_offsets = jax.lax.with_sharding_constraint(
-        input_offsets, NamedSharding(mesh, P("model")))
-    send_sizes = jax.lax.with_sharding_constraint(
-        send_sizes, NamedSharding(mesh, P("model")))
-    output_offsets = jax.lax.with_sharding_constraint(
-        output_offsets, NamedSharding(mesh, P("model")))
-
     def _ragged_all_to_all(operand, input_offsets, send_sizes, output_offsets,
                            recv_sizes):
         output = jnp.zeros_like(operand)
@@ -317,7 +324,7 @@ def expert_sharded_gmm(
             send_sizes_of_shard,
             output_offsets_of_shard,
             recv_sizes_of_shard,
-            axis_name=
+            axis_name=ShardingAxisName.EXPERT)
 
     # Use ragged_all_to_all to send the result from gmm for each expert to all
     # the shards. In the working example, the result would be:
@@ -336,56 +343,74 @@ def expert_sharded_gmm(
     # D, D, D, D  D, D, D, D  D, D, D, D  D, D, D, D
     # D, D, D, D  D, D, D, D  D, D, D, D  D, D, D, D
     #  shard-0     shard-1     shard-2     shard-3
-    return shard_map(
+    return jax.shard_map(
         _ragged_all_to_all,
         mesh=mesh,
-        in_specs=(
-        out_specs=(P()),
-
+        in_specs=(ep_p_spec, ep_p_spec, ep_p_spec, ep_p_spec, P()),
+        out_specs=(P(ShardingAxisName.MLP_DATA)),
+        check_vma=False,
     )(gmm_res, input_offsets, send_sizes, output_offsets, recv_sizes)
 
 
+@functools.partial(
+    jax.jit,
+    static_argnames=(
+        "topk",
+        "renormalize",
+        "mesh",
+        "use_ep",
+        "activation",
+    ),
+)
 def fused_moe_func(
     hidden_states: jax.Array,
     w1: jax.Array,
     w2: jax.Array,
+    w1_scale: jax.Array | None,
+    w2_scale: jax.Array | None,
     w1_bias: jax.Array | None,
     w2_bias: jax.Array | None,
     gating_output: jax.Array,
     topk: int,
-    global_num_experts: int,
     renormalize: bool,
-    reduce_results: bool,
     mesh: Mesh,
     use_ep: bool,
     activation: str,
-):
-    """
+) -> jax.Array:
+    """Route tokens in hidden_states into each experts based on routing.
+
     Args:
-        hidden_states: [
-        w1: [num_experts, intermediate_size * 2, hidden_size]
-        w2: [num_experts, hidden_size, intermediate_size]
-
+        hidden_states: [num_tokens, hidden_size]
+        w1: first moe weights [num_experts, intermediate_size * 2, hidden_size]
+        w2: second moe weights [num_experts, hidden_size, intermediate_size]
+        w1_scale: w1 scale [num_experts, num_blocks, 1, intermediate_size * 2]
+        w2_scale: w2 scale [num_experts, num_blocks, 1, hidden_size]
+        w1_bias: optional bias of w1 [num_experts, 1, intermediate_size * 2]
+        w2_bias: optional bias of w2 [num_experts, 1, hidden_size]
+        gating_output: routing information of tokens [num_tokens, num_experts]
+        topk: number of experts to choose per token.
+        renormalize: normalize gating_output.
+        mesh: mesh to perform moe.
+        use_ep: use expert parallelism.
+        activation: activation function to perform on the output of w1.
+
+    Returns:
+        Output of moe operation [num_tokens, hidden_size]
     """
-
-
-        raise NotImplementedError(
-            "Bias is not supported when using expert parallelism.")
-    orig_shape = hidden_states.shape
-    hidden_size = hidden_states.shape[-1]
-    num_tokens = hidden_states.size // hidden_size
-    assert global_num_experts == w1.shape[0]
-    ep_size = mesh.shape["model"]  # only used if use_ep is True.
-    intermediate_size = w2.shape[-1]
+    num_tokens, hidden_size = hidden_states.shape
+    global_num_experts, _, padded_hidden_size = w1.shape
     dtype = hidden_states.dtype
+
     assert (num_tokens * topk) % 16 == 0, (
         "The kernel requires num_tokens * topk to be a multiple of "
         f"16 but got {num_tokens}*{topk}={num_tokens*topk}")
 
-
-    gating_output = gating_output.reshape(num_tokens, global_num_experts)
+    assert gating_output.shape == (num_tokens, global_num_experts)
 
     topk_weights = jax.nn.softmax(gating_output.astype(jnp.float32), axis=-1)
+    # All-gather topk weights for attention dp
+    topk_weights = jax.lax.with_sharding_constraint(
+        topk_weights, NamedSharding(mesh, P(ShardingAxisName.MLP_DATA, None)))
     topk_weights, topk_indices = jax.lax.top_k(topk_weights, k=topk)
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)
@@ -405,152 +430,77 @@ def fused_moe_func(
         x = hidden_states_local[token_indices_sorted]
         return x, group_sizes_local, topk_argsort_revert_indices
 
-    x, group_sizes, topk_argsort_revert_indices = shard_map(
+    x, group_sizes, topk_argsort_revert_indices = jax.shard_map(
        _process_tokens_locally,
         mesh=mesh,
-        in_specs=(P(
-
-
+        in_specs=(P(ShardingAxisName.MLP_DATA,
+                    None), P(ShardingAxisName.MLP_DATA, None)),
+        out_specs=(P(ShardingAxisName.MLP_DATA, None),
+                   P(ShardingAxisName.MLP_DATA), P(ShardingAxisName.MLP_DATA)),
     )(hidden_states, topk_indices)
+
+    x = jnp.pad(x, ((0, 0), (0, padded_hidden_size - hidden_size)))
+
     if use_ep:
         x = expert_sharded_gmm(
             x,
             w1,
-
-            transpose_rhs=True,
-            mesh=mesh,
-            num_experts=global_num_experts,
-            ep_size=ep_size,
-        )
-        x1, x2 = x[..., :intermediate_size], x[..., intermediate_size:]
-    else:
-        x1, x2 = tensor_sharded_gmm_merged_column_parallel(
-            x,
-            w1,
+            w1_scale,
             w1_bias,
             group_sizes,
-
+            is_last_expert=False,
             mesh=mesh,
-            intermediate_size=intermediate_size,
         )
+        x1, x2 = jnp.split(x, 2, -1)
 
-
+        x = activation_fn(activation, x1, x2)
 
-    if use_ep:
         x = expert_sharded_gmm(
             x,
             w2,
+            w2_scale,
+            w2_bias,
            group_sizes,
-
+            is_last_expert=True,
            mesh=mesh,
-            num_experts=global_num_experts,
-            ep_size=ep_size,
         )
     else:
-
-            x,
+        x1, x2 = tensor_sharded_gmm_merged_column_parallel(
+            x,
+            w1,
+            w1_scale,
+            w1_bias,
+            group_sizes,
+            mesh=mesh,
+        )
+
+        x = activation_fn(activation, x1, x2)
+
         x = tensor_sharded_gmm_row_parallel(
             x,
             w2,
+            w2_scale,
            w2_bias,
            group_sizes,
-            transpose_rhs=True,
             mesh=mesh,
         )
 
     def _finalize_output(x_local, topk_argsort_revert_indices_local,
                          topk_weights_local):
         x_local = x_local[topk_argsort_revert_indices_local].reshape(
-            -1, topk,
+            -1, topk, padded_hidden_size)
         x_local = x_local * jnp.expand_dims(topk_weights_local, axis=-1)
         x_local = x_local.sum(axis=-2)
         return x_local
 
-    x = shard_map(
+    x = jax.shard_map(
         _finalize_output,
         mesh=mesh,
-        in_specs=(P(
-
-
+        in_specs=(P(ShardingAxisName.MLP_DATA,
+                    None), P(ShardingAxisName.MLP_DATA),
+                  P(ShardingAxisName.MLP_DATA, None)),
+        out_specs=(P(ShardingAxisName.ATTN_DATA, None)),
+        check_vma=False,
     )(x, topk_argsort_revert_indices, topk_weights)
-    x = x.reshape(orig_shape)
 
-
-    x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P("data")))
-    return x
-
-
-@functools.partial(
-    jax.jit,
-    static_argnames=(
-        "topk",
-        "global_num_experts",
-        "renormalize",
-        "reduce_results",
-        "mesh",
-        "use_ep",
-        "activation",
-    ),
-)
-def fused_moe_func_padded(
-    hidden_states: jax.Array,
-    w1: jax.Array,
-    w2: jax.Array,
-    w1_bias: jax.Array | None,
-    w2_bias: jax.Array | None,
-    gating_output: jax.Array,
-    topk: int,
-    global_num_experts: int,
-    renormalize: bool,
-    reduce_results: bool,
-    mesh: Mesh,
-    use_ep: bool,
-    activation: str,
-):
-    # TODO(fanhongmin@google.com): Once the jax runner pads the input, we no longer need this.
-    hidden_size = hidden_states.shape[-1]
-    num_tokens = hidden_states.size // hidden_size
-    if num_tokens * topk < 16:
-        assert 16 % (num_tokens *
-                     topk) == 0, f"Cannot pad to 16: {num_tokens=}, {topk=}"
-        n_repeats = 16 // (num_tokens * topk)
-
-        reps = (n_repeats, ) + (1, ) * (hidden_states.ndim - 1)
-        expanded_hidden_states = jnp.tile(hidden_states, reps)
-
-        reps = (n_repeats, ) + (1, ) * (gating_output.ndim - 1)
-        expanded_gating_output = jnp.tile(gating_output, reps)
-
-        expanded_x = fused_moe_func(
-            expanded_hidden_states,
-            w1,
-            w2,
-            w1_bias,
-            w2_bias,
-            expanded_gating_output,
-            topk,
-            global_num_experts,
-            renormalize,
-            reduce_results,
-            mesh,
-            use_ep,
-            activation,
-        )
-        x = expanded_x[:hidden_states.shape[0]]
-        return x
-    else:
-        return fused_moe_func(
-            hidden_states,
-            w1,
-            w2,
-            w1_bias,
-            w2_bias,
-            gating_output,
-            topk,
-            global_num_experts,
-            renormalize,
-            reduce_results,
-            mesh,
-            use_ep,
-            activation,
-        )
+    return x[:num_tokens, :hidden_size]