tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,615 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
from dataclasses import InitVar, dataclass
from functools import partial
from typing import Optional, Tuple

import jax
import jax.numpy as jnp
from flax import nnx
from flax.typing import Sharding
from jax.sharding import PartitionSpec
from jaxtyping import Float
from qwix._src.core.ragged_dot import ragged_dot as qwix_ragged_dot
from qwix._src.providers import ptq

from tpu_inference.layers.jax.base import create_param
from tpu_inference.layers.jax.layers import FlaxUtils
from tpu_inference.layers.jax.moe.moe import CombineExperts, MoE
from tpu_inference.models.jax.utils.qwix.qwix_utils import (
    manually_quantize_qwix_activation, manually_quantize_qwix_weight)

modeling_flax_utils = FlaxUtils()


@dataclass
class DeepSeekV3Router(nnx.Module):
    """Router module for Mixture-of-Experts (MoE) layers.

    This module determines which experts each token should be routed to based on the input.
    """

    hidden_size: int
    num_experts: int
    num_experts_per_tok: int
    n_groups: int
    topk_groups: int
    norm_topk_prob: bool
    routed_scaling_factor: float
    dtype: jnp.dtype
    rngs: InitVar[nnx.Rngs]

    # Sharding Attributes
    activation_ffw_td: Sharding = ()
    ed_sharding: Sharding = ()
    e_sharding: Sharding = ()

    random_init: bool = False

    router_bias_dtype: jnp.dtype = jnp.float32

    def get_topk_indices(self, scores_TE: Float) -> Float:
        """Gets the top-k indices of the scores.

        Args:
            scores_TE: The scores to get the top-k indices of. Shape (sequence, num_experts).

        Returns:
            The top-k indices of the scores. Shape (sequence, num_experts_per_tok).
        """

        scores_TE = scores_TE + self.bias_E
        if self.n_groups > 1:
            experts_per_group = self.num_experts // self.n_groups
            group_scores_TGM = jnp.reshape(
                scores_TE, (-1, self.n_groups, experts_per_group))
            group_scores_TG2 = jax.lax.top_k(group_scores_TGM, k=2)[0]
            group_scores_TG = jnp.sum(group_scores_TG2, axis=-1)
            indices = jax.lax.top_k(group_scores_TG, k=self.topk_groups)[1]

            mask_TG = jnp.any(jnp.arange(
                self.n_groups)[:, None] == indices[..., None, :],
                              axis=-1)
            mask_TE = jnp.repeat(mask_TG,
                                 scores_TE.shape[-1] // mask_TG.shape[-1], -1)
            scores_TE = jnp.where(mask_TE, scores_TE, 0.0)

        indices_TX = jax.lax.top_k(scores_TE, k=self.num_experts_per_tok)[1]

        return indices_TX

    def __call__(self, x_TD: Float) -> Tuple[Float, Float]:
        """Routes tokens to the top k experts.

        Args:
            x_TD: Input array of shape (sequence, d_model).

        Returns:
            A tuple containing:
            - weights: Normalized weights for selected experts, shape (sequence, num_experts_per_tok).
            - indices: Indices of selected experts, shape (sequence, num_experts_per_tok).
        """
        x_TD = jnp.asarray(x_TD, self.dtype)
        x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)

        scores_TE = jnp.einsum("TD,DE -> TE", x_TD, self.kernel_DE.value)
        scores_TE = nnx.sigmoid(scores_TE)

        original_scores_TE = scores_TE
        topk_indices_TX = self.get_topk_indices(scores_TE)
        weights_TX = jnp.take_along_axis(original_scores_TE,
                                         topk_indices_TX,
                                         axis=-1)

        if self.norm_topk_prob:
            weights_TX /= jnp.sum(weights_TX, axis=-1)[..., None] + 1e-20

        weights_TX *= self.routed_scaling_factor

        return weights_TX, topk_indices_TX

    def __post_init__(self, rngs: nnx.Rngs):
        """Generates the router kernel (weights and bias) for routing."""
        D = self.hidden_size
        E = self.num_experts
        self.kernel_DE = create_param(rngs,
                                      shape=(D, E),
                                      dtype=self.dtype,
                                      sharding=self.ed_sharding,
                                      random_init=self.random_init)
        self.bias_E = create_param(rngs,
                                   shape=(E, ),
                                   dtype=self.router_bias_dtype,
                                   sharding=self.e_sharding,
                                   random_init=self.random_init)


@dataclass(kw_only=True)
class SparseMoE(MoE):
    """Mixture-of-Experts (MoE) Routed MLP Layer.

    This module implements a Sparse MoE layer with a router and multiple expert MLPs.

    Attributes:
        num_experts_per_tok: The number of experts each token is routed to.
        tile_size: A tuple (batch, activation_dim, weight_dim) for GMM tiling.
        use_megablox: If True, uses the MegaBlox GMM kernel.
        mesh: The device mesh.
        # TODO: need to redesign this I/O for parallelism
        num_expert_parallelism: The size of the 'expert' mesh dimension.
        # TODO: determine if we get it from external or extract it in the MoE class
        is_batch_sharded_by_expert: True if batch is sharded over the 'expert' dim.
    """
    num_experts_per_tok: int
    # TODO: tile size is (tile_batch_seq, tile_activation_dim, tile_weight_dim,) from MaxText
    tile_size: tuple[int, int, int] = (128, 64, 128)
    use_megablox: bool = False
    mesh: jax.sharding.Mesh
    # This should be set if and only if you have quantized your model (via Qwix)
    quantized_dtype: Optional[jnp.dtype] = None

    def __post_init__(self, rngs: nnx.Rngs):
        super().__post_init__(rngs)
        self.combine_experts = CombineExperts(dtype=self.dtype)

        # Derive the expert sharding
        self.expert_axis_name = self.edf_sharding[0]
        if self.expert_axis_name is None:
            self.num_expert_parallelism = 1
        else:
            self.num_expert_parallelism = self.mesh.shape[
                self.expert_axis_name]

        # Derive whether the data is sharded by expert
        self.data_axis_name = self.activation_ffw_td[0]
        self.is_batch_sharded_by_expert = (
            self.expert_axis_name is not None) and (self.expert_axis_name
                                                    == self.data_axis_name)

    def _sort_activations(self, inputs: jax.Array,
                          sort_indices: jax.Array) -> jax.Array:
        """Sorts activations (inputs) by `sort_indices` for the forward pass."""
        return inputs[sort_indices, ...]

    @staticmethod
    def get_all_to_all_params(
        all_shards_group_sizes,
        shard_id,
        num_expert_parallelism,
        is_batch_sharded=True,
    ):
        """Generates params for ragged_all_to_all communication."""

        class TransformStrategy(enum.Enum):
            INPUT_OFFSET = enum.auto()
            SEND_SIZE = enum.auto()
            OUTPUT_OFFSET = enum.auto()
            RECV_SIZE = enum.auto()

        def transform_array(input_array, shard_id, strategy, is_batch_sharded):
            if is_batch_sharded:
                if strategy == TransformStrategy.INPUT_OFFSET:
                    local_array = input_array[shard_id]
                    return jnp.concatenate(
                        (jnp.array([0]), jnp.cumsum(local_array)[:-1]))
                elif strategy == TransformStrategy.SEND_SIZE:
                    return input_array[shard_id]
                elif strategy == TransformStrategy.OUTPUT_OFFSET:
                    zero_row = jnp.zeros((1, ) + input_array.shape[1:],
                                         dtype=input_array.dtype)
                    array_with_zeros = jnp.concatenate((zero_row, input_array),
                                                       axis=0)
                    cumulated_array = jnp.cumsum(array_with_zeros,
                                                 axis=0,
                                                 dtype=input_array.dtype)
                    return cumulated_array[shard_id]
                elif strategy == TransformStrategy.RECV_SIZE:
                    return input_array[:, shard_id]
                else:
                    raise ValueError(
                        f"Unknown transform array strategy: {strategy}")
            else:
                if strategy == TransformStrategy.INPUT_OFFSET:
                    return jnp.zeros(num_expert_parallelism,
                                     dtype=input_array.dtype)
                elif strategy == TransformStrategy.SEND_SIZE:
                    return jnp.repeat(input_array[shard_id],
                                      num_expert_parallelism)
                elif strategy == TransformStrategy.OUTPUT_OFFSET:
                    output_offset = jnp.concatenate(
                        (jnp.array([0]),
                         jnp.cumsum(input_array[:-1])))[shard_id]
                    return jnp.repeat(output_offset, num_expert_parallelism)
                elif strategy == TransformStrategy.RECV_SIZE:
                    return input_array
                else:
                    raise ValueError(
                        f"Unknown transform array strategy: {strategy}")

        input_offsets = transform_array(all_shards_group_sizes, shard_id,
                                        TransformStrategy.INPUT_OFFSET,
                                        is_batch_sharded)
        send_sizes = transform_array(all_shards_group_sizes, shard_id,
                                     TransformStrategy.SEND_SIZE,
                                     is_batch_sharded)
        output_offsets = transform_array(all_shards_group_sizes, shard_id,
                                         TransformStrategy.OUTPUT_OFFSET,
                                         is_batch_sharded)
        recv_sizes = transform_array(all_shards_group_sizes, shard_id,
                                     TransformStrategy.RECV_SIZE,
                                     is_batch_sharded)
        return input_offsets, send_sizes, output_offsets, recv_sizes

    def _local_permute(
        self,
        inputs,
        global_group_sizes,
        local_expert_size,
        shard_index,
        is_offset=False,
        global_sorted_experts=None,
    ):
        """Permutes tokens locally within an expert shard."""
        # global_group_sizes: (token parallelism, num_total_experts)
        # all_shard_local_sizes: (token parallelism, num local experts in the shard)
        all_shard_local_sizes = jax.lax.dynamic_slice_in_dim(
            global_group_sizes,
            shard_index * local_expert_size,
            local_expert_size,
            axis=1,
        )
        local_sizes = all_shard_local_sizes.reshape(-1)

        # local_group_size: (num local experts in the shard, )
        local_group_size = jnp.sum(all_shard_local_sizes, axis=0)

        # When tokens are replicated across devices
        if is_offset:
            global_sorted_shard_assignments = jnp.floor_divide(
                global_sorted_experts, local_expert_size)
            expert_indices = jnp.where(
                global_sorted_shard_assignments == shard_index,
                jnp.mod(global_sorted_experts, local_expert_size),
                local_expert_size,
            )

        # When tokens are sharded across devices
        else:
            base_indices = jnp.mod(jnp.arange(local_sizes.shape[0]),
                                   local_expert_size)
            expert_indices = jnp.repeat(base_indices,
                                        local_sizes,
                                        total_repeat_length=inputs.shape[0])

        sorted_indices = jnp.argsort(expert_indices)
        # Sort the inputs based on the local expert_indices.
        sorted_inputs = self._sort_activations(inputs, sorted_indices)
        # Sorted local expert ids from 0 to the local expert size.
        sorted_experts_ids = expert_indices[sorted_indices]
        return (
            sorted_inputs,
            sorted_indices,
            local_group_size,
            sorted_experts_ids,
        )

    def _permute(self, inputs_TD: Float, selected_experts_TX: jax.Array):
        """Global permute: sorts tokens by assigned expert."""
        # suffix t = T * X = total_assignments for the local tokens (T) on this device.
        total_tokens = inputs_TD.shape[0]
        flat_expert_indices = selected_experts_TX.flatten()
        sort_indices_t = jnp.argsort(flat_expert_indices)

        replicated_inputs_tD = jnp.repeat(inputs_TD,
                                          self.num_experts_per_tok,
                                          axis=0)
        sorted_inputs_tD = self._sort_activations(replicated_inputs_tD,
                                                  sort_indices_t)

        # Number of tokens assigned to each expert.
        group_sizes_E = jnp.bincount(flat_expert_indices,
                                     length=self.num_local_experts)

        expert_ids = jnp.arange(self.num_local_experts)
        total_assignments = total_tokens * self.num_experts_per_tok
        sorted_expert_assignments_t = jnp.repeat(
            expert_ids,
            repeats=group_sizes_E,
            total_repeat_length=total_assignments)

        return (
            sorted_inputs_tD,
            sort_indices_t,
            group_sizes_E,
            sorted_expert_assignments_t,
        )

    def _unpermute(self, processed_tokens: jax.Array, sort_indices: jax.Array,
                   router_weights_TX: jax.Array):
        """Unsorts tokens to their original order and combines expert outputs with the router's weights."""
        with jax.named_scope("unpermute"):
            unsorted_tokens_tD = self._sort_activations(
                processed_tokens, jnp.argsort(sort_indices))
            reshaped_tokens_TXD = unsorted_tokens_tD.reshape(
                -1, self.num_experts_per_tok, self.hidden_size)
            return self.combine_experts(reshaped_tokens_TXD, router_weights_TX)

    def _gmm(self, inputs, kernel, group_sizes):
        """Performs a Grouped Matrix Multiply."""
        num_rows = inputs.shape[0]
        pad_amount = (self.tile_size[0] -
                      num_rows % self.tile_size[0]) % self.tile_size[0]
        if pad_amount > 0:
            inputs = jnp.pad(inputs, ((0, pad_amount), (0, 0)))

        if self.use_megablox:
            # TODO: MegaBlox is used in MaxText; keep a placeholder here for a future implementation.
            raise NotImplementedError(
                "MegaBlox kernel call is not implemented.")
        else:
            inputs = manually_quantize_qwix_activation(
                inputs, "ragged_dot", jnp.float8_e4m3fn, [0], {},
                "absmax") if self.quantized_dtype else inputs
            ragged_dot_func = qwix_ragged_dot if self.quantized_dtype else jax.lax.ragged_dot
            output = ragged_dot_func(
                lhs=inputs,
                rhs=kernel,
                group_sizes=group_sizes,
                preferred_element_type=self.dtype,
            )

        if pad_amount > 0:
            output = output[:num_rows, :]
        return output

    @staticmethod
    def _distributed_sparse_moe_fwd(
        self,
        x_TD: jax.Array,
        router_weights_TX: jax.Array,
        selected_experts_TX: jax.Array,
        kernel_gating: jax.Array,
        kernel_up_proj: jax.Array,
        kernel_down_proj: jax.Array,
    ):
        """
        The sparse MoE forward pass with fully distributed logic.
        This assumes it is running on a distributed TPU mesh.
        """

        # 1. Global Permute: permute all tokens across shards.
        (
            sorted_inputs,
            global_sort_indices,
            global_group_sizes,
            global_sorted_experts,
        ) = self._permute(x_TD, selected_experts_TX)

        # TODO: update to 'expert' after we enable expert parallelism; currently experts are sharded along the model axis,
        # or we should derive it from the model init.
        expert_shard_id = jax.lax.axis_index(self.expert_axis_name)
        local_expert_size = self.num_local_experts // self.num_expert_parallelism

        if self.num_expert_parallelism > 1:
            if self.is_batch_sharded_by_expert:
                # When tokens are sharded across devices.
                # In this path, we assume the data (tokens) is fully sharded on expert, namely data_axis_name == expert_axis_name.

                # 2a. Send Tokens To Experts (All-to-All)
                # Gather group sizes from all data shards.
                # all_shards_group_sizes: (data parallelism = expert parallelism, number of total experts)
                all_shards_group_sizes = jax.lax.all_gather(
                    global_group_sizes, axis_name=self.data_axis_name)

                # all_shards_group_sizes_per_expert_shard[i][j] = # tokens on shard[i] to be sent to expert shard[j]
                all_shards_group_sizes_per_expert_shard = jnp.sum(
                    all_shards_group_sizes.reshape(
                        self.num_expert_parallelism,  # data parallelism
                        self.num_expert_parallelism,  # expert parallelism
                        local_expert_size  # Experts per shard
                    ),
                    axis=2)
                input_offsets, send_sizes, output_offsets, recv_sizes = self.get_all_to_all_params(
                    all_shards_group_sizes_per_expert_shard, expert_shard_id,
                    self.num_expert_parallelism)
                # Estimate the buffer size.
                local_total_assignments = x_TD.shape[
                    0] * self.num_experts_per_tok
                global_total_assignments = local_total_assignments * self.num_expert_parallelism
                output_shape_est = jnp.zeros(
                    (global_total_assignments, self.hidden_size),
                    dtype=sorted_inputs.dtype)

                inputs_after_all2all = jax.lax.ragged_all_to_all(
                    sorted_inputs,
                    output_shape_est,
                    input_offsets,
                    send_sizes,
                    output_offsets,
                    recv_sizes,
                    axis_name=self.expert_axis_name)

                # 3a. Local Permute
                # Get the full group sizes from all shards.
                full_global_group_sizes = jax.lax.all_gather(
                    global_group_sizes, axis_name=self.expert_axis_name)
                (
                    compute_inputs,
                    local_sorted_indices,
                    compute_group_sizes,
                    compute_expert_ids,
                ) = self._local_permute(
                    inputs_after_all2all,
                    full_global_group_sizes,
                    local_expert_size,
                    shard_index=expert_shard_id,
                    is_offset=False,
                )

            else:
                # When tokens are replicated across devices.

                # 2. No send all-to-all is needed, as the tokens are sorted and replicated on all devices.
                # 3b. Local "Permute"
                (
                    compute_inputs,
                    local_sorted_indices,
                    compute_group_sizes,
                    compute_expert_ids,
                ) = self._local_permute(
                    sorted_inputs,
                    global_group_sizes[None, :],
                    local_expert_size,
                    shard_index=expert_shard_id,
                    is_offset=True,
                    global_sorted_experts=global_sorted_experts,
                )

                # Calculate group sizes for the return all-to-all.
                reshaped_group_sizes = jnp.sum(global_group_sizes.reshape(
                    -1, local_expert_size),
                                               axis=1)
                mask = compute_expert_ids < local_expert_size
                compute_inputs = compute_inputs * mask[..., None]

        else:
            # --- NO EXPERT PARALLELISM ---
            compute_inputs = sorted_inputs
            compute_group_sizes = global_group_sizes
            compute_expert_ids = global_sorted_experts
            local_sorted_indices = jnp.arange(sorted_inputs.shape[0])

        # 4. Compute: Apply experts using Grouped Matrix Multiply.
        with jax.named_scope("gating"):
            # compute_inputs: (local total assignments, D)
            gating_TEF = self._gmm(compute_inputs, kernel_gating,
                                   compute_group_sizes)
            activated_gating_TEF = modeling_flax_utils.ACT2FN[self.hidden_act](
                gating_TEF)

        with jax.named_scope("up_projection"):
            up_proj_TEF = self._gmm(compute_inputs, kernel_up_proj,
                                    compute_group_sizes)

        fuse_TEF = activated_gating_TEF * up_proj_TEF

        with jax.named_scope("down_projection"):
            # intermediate_output: (local total assignments, D)
            intermediate_output = self._gmm(fuse_TEF, kernel_down_proj,
                                            compute_group_sizes)

        # 5. Return Results (All-to-All)
        if self.num_expert_parallelism > 1:
            local_total_assignments = x_TD.shape[0] * self.num_experts_per_tok
            output_shape = jnp.zeros(
                (local_total_assignments, self.hidden_size),
                dtype=intermediate_output.dtype)

            if self.is_batch_sharded_by_expert:
                # When tokens are sharded across devices.
                # Unsort locally before sending back.
                local_output = self._sort_activations(
                    intermediate_output, jnp.argsort(local_sorted_indices))

                input_offsets, send_sizes, output_offsets, recv_sizes = self.get_all_to_all_params(
                    jnp.transpose(all_shards_group_sizes),
                    expert_shard_id,
                    self.num_expert_parallelism,
                )
                final_intermediate_output = jax.lax.ragged_all_to_all(
                    local_output,
                    output_shape,
                    input_offsets,
                    send_sizes,
                    output_offsets,
                    recv_sizes,
                    axis_name=self.expert_axis_name)
            else:
                # When tokens are replicated across devices.
                input_offsets, send_sizes, output_offsets, recv_sizes = self.get_all_to_all_params(
                    reshaped_group_sizes,
                    expert_shard_id,
                    self.num_expert_parallelism,
                    is_batch_sharded=False,
                )
                final_intermediate_output = jax.lax.ragged_all_to_all(
                    intermediate_output,
                    output_shape,
                    input_offsets,
                    send_sizes,
                    output_offsets,
                    recv_sizes,
                    axis_name=self.expert_axis_name)
        else:
            final_intermediate_output = intermediate_output

        # 6. Global Unpermute (on the data shard)
        with jax.named_scope("unpermute"):
            output_TD = self._unpermute(final_intermediate_output,
                                        global_sort_indices, router_weights_TX)

        return output_TD

    def __call__(self, x_TD: Float):
        """Performs the forward pass of the Sparse MoE layer."""
        x_TD = jnp.asarray(x_TD, self.dtype)
        x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
        router_weights_TX, selected_experts_TX = self.router(x_TD)

        in_specs = (
            PartitionSpec(),  # Replicated `self`
            PartitionSpec(*self.activation_ffw_td),  # Sharded x_TD
            PartitionSpec(),  # Replicated router_weights_TX
            PartitionSpec(),  # Replicated selected_experts_TX
            PartitionSpec(*self.edf_sharding),  # Sharded gating kernel
            PartitionSpec(*self.edf_sharding),  # Sharded up-projection kernel
            PartitionSpec(
                *self.efd_sharding),  # Sharded down-projection kernel
        )
        out_specs = PartitionSpec(*self.activation_ffw_td)

        mapped_moe_fwd = partial(jax.shard_map,
                                 mesh=self.mesh,
                                 in_specs=in_specs,
                                 out_specs=out_specs,
                                 check_vma=False)(
                                     SparseMoE._distributed_sparse_moe_fwd)

        kernel_gating_EDF = self.kernel_gating_EDF.value
        kernel_up_proj_EDF = self.kernel_up_proj_EDF.value
        kernel_down_proj_EFD = self.kernel_down_proj_EFD.value

        if self.quantized_dtype:
            if not isinstance(kernel_gating_EDF, ptq.WithAux):
                kernel_gating_EDF = manually_quantize_qwix_weight(
                    kernel_gating_EDF, self.quantized_dtype, [0, 2], {},
                    "absmax")
            if not isinstance(kernel_up_proj_EDF, ptq.WithAux):
                kernel_up_proj_EDF = manually_quantize_qwix_weight(
                    kernel_up_proj_EDF, self.quantized_dtype, [0, 2], {},
                    "absmax")
            if not isinstance(kernel_down_proj_EFD, ptq.WithAux):
                kernel_down_proj_EFD = manually_quantize_qwix_weight(
                    kernel_down_proj_EFD, self.quantized_dtype, [0, 1], {},
                    "absmax")
            kernel_gating_EDF = kernel_gating_EDF.array
            kernel_up_proj_EDF = kernel_up_proj_EDF.array
            kernel_down_proj_EFD = kernel_down_proj_EFD.array

        return mapped_moe_fwd(self, x_TD, router_weights_TX,
                              selected_experts_TX, kernel_gating_EDF,
                              kernel_up_proj_EDF, kernel_down_proj_EFD)