tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
|
@@ -1,3 +1,16 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
1
14
|
"""TPU-Friendly Fused Mixture of Experts (MoE) kernel."""
|
|
2
15
|
|
|
3
16
|
import functools
|
|
@@ -7,7 +20,6 @@ import jax.numpy as jnp
|
|
|
7
20
|
from jax import lax
|
|
8
21
|
from jax._src import dtypes
|
|
9
22
|
from jax.experimental import pallas as pl
|
|
10
|
-
from jax.experimental import shard_map
|
|
11
23
|
from jax.experimental.pallas import tpu as pltpu
|
|
12
24
|
|
|
13
25
|
P = jax.sharding.PartitionSpec
|
|
@@ -20,7 +32,8 @@ def align_to(x, a):
|
|
|
20
32
|
|
|
21
33
|
|
|
22
34
|
def get_dtype_packing(dtype):
|
|
23
|
-
bits = dtypes.bit_width(dtype)
|
|
35
|
+
bits = (dtypes.bit_width(dtype)
|
|
36
|
+
if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
|
|
24
37
|
return 32 // bits
|
|
25
38
|
|
|
26
39
|
|
|
@@ -35,13 +48,50 @@ def broadcast_minor(src, shape):
|
|
|
35
48
|
axis=-1)[..., :shape[-1]]
|
|
36
49
|
|
|
37
50
|
|
|
51
|
+
def swigluoai(gate: jax.Array,
|
|
52
|
+
up: jax.Array,
|
|
53
|
+
*,
|
|
54
|
+
alpha: float = 1.702,
|
|
55
|
+
limit: float = 7.0) -> jax.Array:
|
|
56
|
+
"""Activation used in some models such as GPT-OSS."""
|
|
57
|
+
gate = jnp.clip(gate, a_max=limit)
|
|
58
|
+
up = jnp.clip(up, a_min=-limit, a_max=limit)
|
|
59
|
+
glu = gate * jax.nn.sigmoid(alpha * gate)
|
|
60
|
+
return (up + 1.0) * glu
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def activation_fn(acc1, acc3, act_fn):
|
|
64
|
+
if act_fn == "silu":
|
|
65
|
+
return jax.nn.silu(acc1) * acc3
|
|
66
|
+
elif act_fn == "gelu":
|
|
67
|
+
return jax.nn.gelu(acc1) * acc3
|
|
68
|
+
elif act_fn == "swigluoai":
|
|
69
|
+
return swigluoai(acc1, acc3)
|
|
70
|
+
else:
|
|
71
|
+
raise RuntimeError(f"Unsupported activation function: {act_fn}")
|
|
72
|
+
|
|
73
|
+
|
|
38
74
|
def ref_moe(
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
75
|
+
tokens: jax.Array, # (num_tokens, hidden_size)
|
|
76
|
+
w1: jax.Array, # (num_experts, 2, hidden_size, intermediate_size)
|
|
77
|
+
w2: jax.Array, # (num_experts, intermediate_size, hidden_size)
|
|
78
|
+
gating_output: jax.Array, # (num_tokens, num_experts)
|
|
79
|
+
top_k: int,
|
|
80
|
+
*,
|
|
81
|
+
renormalize_topk_logits: bool = False,
|
|
82
|
+
act_fn: str = "silu",
|
|
83
|
+
subc_quant_wsz: int | None = None,
|
|
84
|
+
w1_scale:
|
|
85
|
+
(
|
|
86
|
+
jax.Array | None
|
|
87
|
+
) = None, # F32(num_experts, 2, hidden_size //subc_quant_wsz, 1, intermediate_size)
|
|
88
|
+
w2_scale:
|
|
89
|
+
(
|
|
90
|
+
jax.Array | None
|
|
91
|
+
) = None, # F32(num_experts, intermediate_size // subc_quant_wsz, 1, hidden_size)
|
|
92
|
+
b1: jax.Array
|
|
93
|
+
| None = None, # F32(num_experts, 2, 1, intermediate_size)
|
|
94
|
+
b2: jax.Array | None = None, # F32(num_experts, 1, hidden_size)
|
|
45
95
|
):
|
|
46
96
|
n_tokens = tokens.shape[0] # num_tokens
|
|
47
97
|
|
|
@@ -53,11 +103,16 @@ def ref_moe(
|
|
|
53
103
|
top_k_logits, top_k_indices = lax.top_k(
|
|
54
104
|
gating_logits, top_k) # [num_tokens, top_k], [num_tokens, top_k]
|
|
55
105
|
|
|
106
|
+
if renormalize_topk_logits:
|
|
107
|
+
top_k_logits = top_k_logits / jnp.sum(
|
|
108
|
+
top_k_logits, axis=-1, keepdims=True)
|
|
109
|
+
|
|
56
110
|
t_outputs = []
|
|
111
|
+
hidden_size, intermediate_size = w1.shape[-2:]
|
|
57
112
|
|
|
58
113
|
# Process each token individually
|
|
59
114
|
for i in range(n_tokens):
|
|
60
|
-
curr_token = jnp.expand_dims(tokens[i], axis=0) # [1,
|
|
115
|
+
curr_token = jnp.expand_dims(tokens[i], axis=0) # [1, hidden_size]
|
|
61
116
|
assigned_expert_ids = top_k_indices[
|
|
62
117
|
i] # [top_k] - indices of selected experts for token i
|
|
63
118
|
tok_expert_act = []
|
|
@@ -65,10 +120,24 @@ def ref_moe(
|
|
|
65
120
|
# Process each selected expert for the current token
|
|
66
121
|
for expert_id in assigned_expert_ids:
|
|
67
122
|
# Get expert weights
|
|
123
|
+
expert_w1 = w1[expert_id, 0].astype(jnp.float32)
|
|
124
|
+
expert_w3 = w1[expert_id, 1].astype(jnp.float32)
|
|
125
|
+
if w1_scale is not None:
|
|
126
|
+
expert_w1 *= jnp.repeat(w1_scale[expert_id, 0, :, 0],
|
|
127
|
+
subc_quant_wsz,
|
|
128
|
+
axis=0)[:hidden_size]
|
|
129
|
+
expert_w3 *= jnp.repeat(w1_scale[expert_id, 1, :, 0],
|
|
130
|
+
subc_quant_wsz,
|
|
131
|
+
axis=0)[:hidden_size]
|
|
68
132
|
expert_weight_1 = jnp.concat(
|
|
69
|
-
[
|
|
70
|
-
axis=-1) # [
|
|
71
|
-
expert_weight_2 = w2[expert_id]
|
|
133
|
+
[expert_w1, expert_w3],
|
|
134
|
+
axis=-1) # [hidden_size, 2 * intermediate_size]
|
|
135
|
+
expert_weight_2 = w2[expert_id].astype(
|
|
136
|
+
jnp.float32) # [intermediate_size, hidden_size]
|
|
137
|
+
if w2_scale is not None:
|
|
138
|
+
expert_weight_2 *= jnp.repeat(w2_scale[expert_id, :, 0],
|
|
139
|
+
subc_quant_wsz,
|
|
140
|
+
axis=0)[:intermediate_size]
|
|
72
141
|
|
|
73
142
|
# First linear layer with SwiGLU activation
|
|
74
143
|
gmm_1_out = curr_token @ expert_weight_1 # [1, 2 * intermediate_size]
|
|
@@ -77,37 +146,34 @@ def ref_moe(
|
|
|
77
146
|
gmm1_w1_proj, gmm1_w3_proj = jnp.split(
|
|
78
147
|
gmm_1_out, 2,
|
|
79
148
|
axis=-1) # [1, intermediate_size], [1, intermediate_size]
|
|
149
|
+
if b1 is not None:
|
|
150
|
+
gmm1_w1_proj += b1[expert_id:expert_id + 1, 0, 0]
|
|
151
|
+
gmm1_w3_proj += b1[expert_id:expert_id + 1, 1, 0]
|
|
80
152
|
|
|
81
153
|
# Apply gated activation: activation(gate) * up
|
|
82
|
-
|
|
83
|
-
act = jax.nn.silu(
|
|
84
|
-
gmm1_w1_proj) * gmm1_w3_proj # [1, intermediate_size]
|
|
85
|
-
elif activation == "gelu":
|
|
86
|
-
act = jax.nn.gelu(
|
|
87
|
-
gmm1_w1_proj) * gmm1_w3_proj # [1, intermediate_size]
|
|
88
|
-
else:
|
|
89
|
-
raise ValueError(
|
|
90
|
-
f"Unsupported activation: {activation}. Use 'silu' or 'gelu'."
|
|
91
|
-
)
|
|
154
|
+
act = activation_fn(gmm1_w1_proj, gmm1_w3_proj, act_fn)
|
|
92
155
|
|
|
93
156
|
# Second linear layer (down projection)
|
|
94
|
-
gmm_2_out = act @ expert_weight_2 # [1,
|
|
157
|
+
gmm_2_out = act @ expert_weight_2 # [1, hidden_size]
|
|
158
|
+
if b2 is not None:
|
|
159
|
+
gmm_2_out += b2[expert_id:expert_id + 1, 0]
|
|
95
160
|
tok_expert_act.append(gmm_2_out)
|
|
96
161
|
|
|
97
162
|
# Combine outputs from all selected experts
|
|
98
163
|
experts_act = jnp.concatenate(tok_expert_act,
|
|
99
|
-
axis=0) # [top_k,
|
|
164
|
+
axis=0) # [top_k, hidden_size]
|
|
100
165
|
|
|
101
166
|
# Weighted sum using top-k gating weights
|
|
102
167
|
top_k_weights = top_k_logits[i] # [top_k]
|
|
103
168
|
top_k_weights = jnp.expand_dims(top_k_weights, axis=1) # [top_k, 1]
|
|
104
169
|
weighted_output = jnp.sum(experts_act * top_k_weights,
|
|
105
170
|
axis=0,
|
|
106
|
-
keepdims=True) # [1,
|
|
171
|
+
keepdims=True) # [1, hidden_size]
|
|
107
172
|
|
|
108
|
-
t_outputs.append(weighted_output)
|
|
173
|
+
t_outputs.append(weighted_output.astype(tokens.dtype))
|
|
109
174
|
|
|
110
|
-
return jnp.concatenate(t_outputs,
|
|
175
|
+
return jnp.concatenate(t_outputs,
|
|
176
|
+
axis=0) # [actual_num_tokens, hidden_size]
|
|
111
177
|
|
|
112
178
|
|
|
113
179
|
def _fused_ep_moe_kernel(
|
|
@@ -115,12 +181,19 @@ def _fused_ep_moe_kernel(
|
|
|
115
181
|
tokens_hbm, # (local_num_tokens, t_packing, hidden_size // t_packing)
|
|
116
182
|
w1_hbm, # (local_num_experts, 2, hidden_size, intermediate_size)
|
|
117
183
|
w2_hbm, # (local_num_experts, intermediate_size, hidden_size)
|
|
184
|
+
# TODO(jevinjiang): We choose F32 scale for easier slicing. The extra
|
|
185
|
+
# latency should be hidden in the pipeline overlaping. But is there a better
|
|
186
|
+
# way to do this?
|
|
187
|
+
w1_scale_hbm, # None | F32(local_num_experts, 2, cdiv(hidden_size, subc_quant_wsz), 1, intermediate_size)
|
|
188
|
+
w2_scale_hbm, # None | F32(local_num_experts, cdiv(intermediate_size, subc_quant_wsz), 1, hidden_size)
|
|
189
|
+
b1_hbm, # None | F32(local_num_experts, 2, 1, intermediate_size)
|
|
190
|
+
b2_hbm, # None | F32(local_num_experts, 1, hidden_size)
|
|
118
191
|
gating_hbm, # (local_num_tokens, padded_num_experts)
|
|
119
192
|
a2a_g_hbm, # (num_experts, bt, t_packing, hidden_size // t_packing)
|
|
120
193
|
# Output
|
|
121
194
|
output_hbm, # (local_num_tokens, hidden_size)
|
|
122
195
|
# Scratch
|
|
123
|
-
t2e_routing_x2_smem, # <bt_sem_id> (2, bt,
|
|
196
|
+
t2e_routing_x2_smem, # <bt_sem_id> (2, bt, padded_top_k)
|
|
124
197
|
d2e_count_x2_smem, # <bt_sem_id> (2, num_devices, 1, padded_num_experts)
|
|
125
198
|
expert_offsets_x2_smem, # <bt_sem_id> (2, 2, padded_num_experts): for a2a_s and a2a_g
|
|
126
199
|
expert_starts_x2_smem, # <bt_sem_id> (2, 1, padded_num_experts)
|
|
@@ -136,6 +209,12 @@ def _fused_ep_moe_kernel(
|
|
|
136
209
|
b_w1_x2_vmem, # <bw_sem_id> (2, t_packing, bd1 // t_packing, bf)
|
|
137
210
|
b_w3_x2_vmem, # <bw_sem_id> (2, t_packing, bd1 // t_packing, bf)
|
|
138
211
|
b_w2_x2_vmem, # <bw_sem_id> (2, t_packing, bf, bd2 // t_packing)
|
|
212
|
+
b_w1_scale_x2_vmem, # None | <bw_sem_id> (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf)
|
|
213
|
+
b_w3_scale_x2_vmem, # None | <bw_sem_id> (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf)
|
|
214
|
+
b_w2_scale_x2_vmem, # None | <bw_sem_id> (2, t_packing, bf // subc_quant_wsz, 1, bd2 // t_packing)
|
|
215
|
+
b_b1_x2_vmem, # None | <bw_sem_id> (2, 1, bf)
|
|
216
|
+
b_b3_x2_vmem, # None | <bw_sem_id> (2, 1, bf)
|
|
217
|
+
b_b2_x2_vmem, # None | <bw_sem_id> (2, t_packing, 1, bd2 // t_packing)
|
|
139
218
|
b_acc_vmem, # F32(bt * num_devices, 1, bf * 2)
|
|
140
219
|
### Semaphores:
|
|
141
220
|
local_sems, # (2, 5): 2 x [b_gating_sem, b_w1_sem, b_w2_sem, b_w3_sem, b_output_sem]
|
|
@@ -145,7 +224,10 @@ def _fused_ep_moe_kernel(
|
|
|
145
224
|
a2a_acc_sem,
|
|
146
225
|
*,
|
|
147
226
|
top_k: int,
|
|
227
|
+
renormalize_topk_logits: bool,
|
|
148
228
|
ep_axis_name: str,
|
|
229
|
+
act_fn: str,
|
|
230
|
+
subc_quant_wsz: int | None = None,
|
|
149
231
|
# Kernel tuning params.
|
|
150
232
|
bt: int, # Block size of local_num_tokens.
|
|
151
233
|
bf: int, # Block size of intermediate_size.
|
|
@@ -160,34 +242,58 @@ def _fused_ep_moe_kernel(
|
|
|
160
242
|
num_devices = lax.axis_size(ep_axis_name)
|
|
161
243
|
local_num_tokens = tokens_hbm.shape[0]
|
|
162
244
|
local_num_experts, intermediate_size, hidden_size = w2_hbm.shape
|
|
163
|
-
# num_experts = local_num_experts * num_devices
|
|
164
|
-
# padded_num_experts = expert_starts_x2_smem.shape[-1]
|
|
165
245
|
right_id = (my_id + 1) % num_devices
|
|
246
|
+
num_experts = a2a_g_hbm.shape[0]
|
|
247
|
+
padded_num_experts = d2e_count_x2_smem.shape[-1]
|
|
248
|
+
padded_top_k = t2e_routing_x2_smem.shape[-1]
|
|
249
|
+
assert padded_num_experts == align_to(num_experts, 128)
|
|
250
|
+
assert padded_top_k == align_to(top_k, 128)
|
|
166
251
|
|
|
167
252
|
t_dtype = tokens_hbm.dtype
|
|
168
253
|
t_packing = get_dtype_packing(t_dtype)
|
|
169
254
|
t_bitwidth = 32 // t_packing
|
|
170
255
|
assert a2a_g_hbm.dtype == t_dtype
|
|
171
|
-
assert w1_hbm.dtype ==
|
|
172
|
-
assert w2_hbm.dtype == t_dtype
|
|
256
|
+
assert w1_hbm.dtype == w2_hbm.dtype
|
|
173
257
|
|
|
174
|
-
|
|
175
|
-
assert
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
258
|
+
assert bd1 % bd1c == 0
|
|
259
|
+
assert bd2 % bd2c == 0
|
|
260
|
+
assert bf % bfc == 0
|
|
261
|
+
assert hidden_size % t_packing == 0
|
|
262
|
+
assert bd1 % t_packing == 0
|
|
263
|
+
assert bd2 % t_packing == 0
|
|
264
|
+
assert bd1c % t_packing == 0
|
|
265
|
+
assert bd2c % t_packing == 0
|
|
266
|
+
|
|
267
|
+
h_per_t_packing = hidden_size // t_packing
|
|
268
|
+
assert tokens_hbm.shape[-1] == h_per_t_packing
|
|
269
|
+
bd1_per_t_packing = bd1 // t_packing
|
|
270
|
+
bd2_per_t_packing = bd2 // t_packing
|
|
271
|
+
bd1c_per_t_packing = bd1c // t_packing
|
|
272
|
+
bd2c_per_t_packing = bd2c // t_packing
|
|
273
|
+
|
|
274
|
+
if subc_quant_wsz is not None:
|
|
275
|
+
assert subc_quant_wsz % 256 == 0
|
|
276
|
+
assert bd1c_per_t_packing == subc_quant_wsz
|
|
277
|
+
assert bfc == subc_quant_wsz
|
|
278
|
+
assert bd1 % subc_quant_wsz == 0
|
|
279
|
+
assert bf % subc_quant_wsz == 0
|
|
280
|
+
assert bd1_per_t_packing % subc_quant_wsz == 0
|
|
281
|
+
assert h_per_t_packing % subc_quant_wsz == 0
|
|
180
282
|
|
|
181
283
|
num_bt = cdiv(local_num_tokens, bt)
|
|
182
284
|
num_bf = cdiv(intermediate_size, bf)
|
|
183
285
|
num_bd1 = cdiv(hidden_size, bd1)
|
|
184
286
|
num_bd2 = cdiv(hidden_size, bd2)
|
|
185
287
|
|
|
288
|
+
def get_mesh_device_id(ep_rank):
|
|
289
|
+
dp_rank = jax.lax.axis_index("data")
|
|
290
|
+
return (dp_rank, ep_rank)
|
|
291
|
+
|
|
186
292
|
def sync_barrier():
|
|
187
293
|
barrier_sem = pltpu.get_barrier_semaphore()
|
|
188
294
|
pltpu.semaphore_signal(
|
|
189
295
|
barrier_sem,
|
|
190
|
-
device_id=(
|
|
296
|
+
device_id=get_mesh_device_id(right_id),
|
|
191
297
|
device_id_type=pltpu.DeviceIdType.MESH,
|
|
192
298
|
)
|
|
193
299
|
pltpu.semaphore_wait(barrier_sem, 1)
|
|
@@ -212,30 +318,44 @@ def _fused_ep_moe_kernel(
|
|
|
212
318
|
sem=b_gating_sem,
|
|
213
319
|
).wait()
|
|
214
320
|
|
|
215
|
-
def get_top_k(input, top_k):
|
|
321
|
+
def get_top_k(input, top_k, renormalize_topk_logits):
|
|
216
322
|
assert len(input.shape) == 2, input.shape
|
|
217
323
|
input = input.astype(jnp.float32)
|
|
324
|
+
padded_k_shape = (input.shape[0], padded_top_k)
|
|
218
325
|
top_k_logits_lst = []
|
|
219
326
|
top_k_indices_lst = []
|
|
220
327
|
t2e = jnp.zeros(input.shape, dtype=jnp.int32)
|
|
221
|
-
t2e_routing = jnp.zeros(
|
|
328
|
+
t2e_routing = jnp.zeros(padded_k_shape, dtype=jnp.int32)
|
|
222
329
|
iota = jax.lax.broadcasted_iota(jnp.int32, input.shape, 1)
|
|
330
|
+
padded_k_iota = jax.lax.broadcasted_iota(jnp.int32, padded_k_shape, 1)
|
|
331
|
+
top_k_logits_sum = jnp.zeros(padded_k_shape, jnp.float32)
|
|
332
|
+
|
|
223
333
|
for k_id in range(top_k):
|
|
224
|
-
# TODO(jevinjiang): return both top_k values and indices in
|
|
334
|
+
# TODO(jevinjiang): return both top_k values and indices in Mosaic
|
|
225
335
|
top_k_logits = jnp.broadcast_to(
|
|
226
|
-
jnp.max(input, axis=1, keepdims=True),
|
|
227
|
-
|
|
336
|
+
jnp.max(input[:, :num_experts], axis=1, keepdims=True),
|
|
337
|
+
padded_k_shape,
|
|
338
|
+
).astype(input.dtype)
|
|
228
339
|
top_k_logits_lst.append(top_k_logits)
|
|
340
|
+
if renormalize_topk_logits:
|
|
341
|
+
top_k_logits_sum += top_k_logits
|
|
229
342
|
# TODO(jevinjiang): support bf16 argmax in Mosaic
|
|
230
343
|
top_k_indices = jnp.broadcast_to(
|
|
231
|
-
jnp.argmax(input, axis=1, keepdims=True),
|
|
344
|
+
jnp.argmax(input[:, :num_experts], axis=1, keepdims=True),
|
|
345
|
+
padded_k_shape,
|
|
346
|
+
)
|
|
232
347
|
top_k_indices_lst.append(top_k_indices)
|
|
233
|
-
t2e_routing = jnp.where(
|
|
234
|
-
|
|
348
|
+
t2e_routing = jnp.where(padded_k_iota == k_id, top_k_indices,
|
|
349
|
+
t2e_routing)
|
|
350
|
+
mask = iota == broadcast_minor(top_k_indices, input.shape)
|
|
235
351
|
t2e += mask.astype(jnp.int32)
|
|
236
352
|
if k_id != top_k - 1:
|
|
237
353
|
input = jnp.where(mask, -jnp.inf, input)
|
|
238
354
|
|
|
355
|
+
if renormalize_topk_logits:
|
|
356
|
+
for k_id in range(top_k):
|
|
357
|
+
top_k_logits_lst[k_id] /= top_k_logits_sum
|
|
358
|
+
|
|
239
359
|
expert_sizes = jnp.sum(t2e, axis=0, keepdims=True)
|
|
240
360
|
expert_starts = jnp.zeros_like(expert_sizes)
|
|
241
361
|
return top_k_logits_lst, t2e_routing, expert_sizes, expert_starts
|
|
@@ -277,7 +397,7 @@ def _fused_ep_moe_kernel(
|
|
|
277
397
|
dst_ref=d2e_count_vmem.at[row_id],
|
|
278
398
|
send_sem=send_sem,
|
|
279
399
|
recv_sem=recv_sem,
|
|
280
|
-
device_id=(
|
|
400
|
+
device_id=get_mesh_device_id(right_id),
|
|
281
401
|
device_id_type=pltpu.DeviceIdType.MESH,
|
|
282
402
|
).wait()
|
|
283
403
|
row_id = (row_id + num_devices - 1) % num_devices
|
|
@@ -359,10 +479,8 @@ def _fused_ep_moe_kernel(
|
|
|
359
479
|
pl.ds(start, remote_sz)],
|
|
360
480
|
send_sem=send_sems.at[e_sem_id],
|
|
361
481
|
recv_sem=recv_sems.at[e_sem_id],
|
|
362
|
-
device_id=(
|
|
363
|
-
|
|
364
|
-
recv_id,
|
|
365
|
-
),
|
|
482
|
+
device_id=get_mesh_device_id(recv_id),
|
|
483
|
+
device_id_type=pltpu.DeviceIdType.MESH,
|
|
366
484
|
).start()
|
|
367
485
|
a2a_s_sends_x2_smem[e_sem_id] = send_sz
|
|
368
486
|
|
|
@@ -406,7 +524,8 @@ def _fused_ep_moe_kernel(
|
|
|
406
524
|
dst_ref=a2a_g_hbm.at[my_e_id, pl.ds(0, remote_sz)],
|
|
407
525
|
send_sem=send_sems.at[e_sem_id],
|
|
408
526
|
recv_sem=a2a_gather_sem,
|
|
409
|
-
device_id=(
|
|
527
|
+
device_id=get_mesh_device_id(recv_id),
|
|
528
|
+
device_id_type=pltpu.DeviceIdType.MESH,
|
|
410
529
|
).start()
|
|
411
530
|
start += sz
|
|
412
531
|
|
|
@@ -435,68 +554,173 @@ def _fused_ep_moe_kernel(
|
|
|
435
554
|
|
|
436
555
|
def start_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id):
|
|
437
556
|
for p in range(t_packing):
|
|
438
|
-
offset = p *
|
|
557
|
+
offset = p * h_per_t_packing + bd1_id * bd1_per_t_packing
|
|
439
558
|
pltpu.make_async_copy(
|
|
440
559
|
src_ref=w1_hbm.at[
|
|
441
560
|
local_e_id,
|
|
442
561
|
0,
|
|
443
|
-
pl.ds(offset,
|
|
562
|
+
pl.ds(offset, bd1_per_t_packing),
|
|
444
563
|
pl.ds(bf_id * bf, bf),
|
|
445
564
|
],
|
|
446
565
|
dst_ref=b_w1_x2_vmem.at[bw1_sem_id, p],
|
|
447
566
|
sem=local_sems.at[bw1_sem_id, 1],
|
|
448
567
|
).start()
|
|
568
|
+
if w1_scale_hbm is not None:
|
|
569
|
+
assert subc_quant_wsz is not None
|
|
570
|
+
pltpu.make_async_copy(
|
|
571
|
+
src_ref=w1_scale_hbm.at[
|
|
572
|
+
local_e_id,
|
|
573
|
+
0,
|
|
574
|
+
pl.ds(
|
|
575
|
+
offset // subc_quant_wsz,
|
|
576
|
+
bd1_per_t_packing // subc_quant_wsz,
|
|
577
|
+
),
|
|
578
|
+
pl.ds(0, 1),
|
|
579
|
+
pl.ds(bf_id * bf, bf),
|
|
580
|
+
],
|
|
581
|
+
dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id, p],
|
|
582
|
+
sem=local_sems.at[bw1_sem_id, 1],
|
|
583
|
+
).start()
|
|
584
|
+
if b1_hbm is not None and bd1_id == 0:
|
|
585
|
+
pltpu.make_async_copy(
|
|
586
|
+
src_ref=b1_hbm.at[local_e_id, 0,
|
|
587
|
+
pl.ds(0, 1),
|
|
588
|
+
pl.ds(bf_id * bf, bf)],
|
|
589
|
+
dst_ref=b_b1_x2_vmem.at[bf_id % 2],
|
|
590
|
+
sem=local_sems.at[bw1_sem_id, 1],
|
|
591
|
+
).start()
|
|
449
592
|
|
|
450
593
|
def start_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id):
|
|
451
594
|
for p in range(t_packing):
|
|
452
|
-
offset = p *
|
|
595
|
+
offset = p * h_per_t_packing + bd2_id * bd2_per_t_packing
|
|
453
596
|
pltpu.make_async_copy(
|
|
454
597
|
src_ref=w2_hbm.at[
|
|
455
598
|
local_e_id,
|
|
456
599
|
pl.ds(bf_id * bf, bf),
|
|
457
|
-
pl.ds(offset,
|
|
600
|
+
pl.ds(offset, bd2_per_t_packing),
|
|
458
601
|
],
|
|
459
602
|
dst_ref=b_w2_x2_vmem.at[bw2_sem_id, p],
|
|
460
603
|
sem=local_sems.at[bw2_sem_id, 2],
|
|
461
604
|
).start()
|
|
605
|
+
if w2_scale_hbm is not None:
|
|
606
|
+
assert subc_quant_wsz is not None
|
|
607
|
+
pltpu.make_async_copy(
|
|
608
|
+
src_ref=w2_scale_hbm.at[
|
|
609
|
+
local_e_id,
|
|
610
|
+
pl.ds(bf_id * bf // subc_quant_wsz, bf //
|
|
611
|
+
subc_quant_wsz),
|
|
612
|
+
pl.ds(0, 1),
|
|
613
|
+
pl.ds(offset, bd2_per_t_packing),
|
|
614
|
+
],
|
|
615
|
+
dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id, p],
|
|
616
|
+
sem=local_sems.at[bw2_sem_id, 2],
|
|
617
|
+
).start()
|
|
618
|
+
if b2_hbm is not None and bf_id == 0:
|
|
619
|
+
pltpu.make_async_copy(
|
|
620
|
+
src_ref=b2_hbm.at[local_e_id,
|
|
621
|
+
pl.ds(0, 1),
|
|
622
|
+
pl.ds(offset, bd2_per_t_packing)],
|
|
623
|
+
dst_ref=b_b2_x2_vmem.at[bd2_id % 2, p],
|
|
624
|
+
sem=local_sems.at[bw2_sem_id, 2],
|
|
625
|
+
).start()
|
|
462
626
|
|
|
463
627
|
def start_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id):
|
|
464
628
|
for p in range(t_packing):
|
|
465
|
-
offset = p *
|
|
629
|
+
offset = p * h_per_t_packing + bd3_id * bd1_per_t_packing
|
|
466
630
|
pltpu.make_async_copy(
|
|
467
631
|
src_ref=w1_hbm.at[
|
|
468
632
|
local_e_id,
|
|
469
633
|
1,
|
|
470
|
-
pl.ds(offset,
|
|
634
|
+
pl.ds(offset, bd1_per_t_packing),
|
|
471
635
|
pl.ds(bf_id * bf, bf),
|
|
472
636
|
],
|
|
473
637
|
dst_ref=b_w3_x2_vmem.at[bw3_sem_id, p],
|
|
474
638
|
sem=local_sems.at[bw3_sem_id, 3],
|
|
475
639
|
).start()
|
|
640
|
+
if w1_scale_hbm is not None:
|
|
641
|
+
assert subc_quant_wsz is not None
|
|
642
|
+
pltpu.make_async_copy(
|
|
643
|
+
src_ref=w1_scale_hbm.at[
|
|
644
|
+
local_e_id,
|
|
645
|
+
1,
|
|
646
|
+
pl.ds(
|
|
647
|
+
offset // subc_quant_wsz,
|
|
648
|
+
bd1_per_t_packing // subc_quant_wsz,
|
|
649
|
+
),
|
|
650
|
+
pl.ds(0, 1),
|
|
651
|
+
pl.ds(bf_id * bf, bf),
|
|
652
|
+
],
|
|
653
|
+
dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id, p],
|
|
654
|
+
sem=local_sems.at[bw3_sem_id, 3],
|
|
655
|
+
).start()
|
|
656
|
+
if b1_hbm is not None and bd3_id == 0:
|
|
657
|
+
pltpu.make_async_copy(
|
|
658
|
+
src_ref=b1_hbm.at[local_e_id, 1,
|
|
659
|
+
pl.ds(0, 1),
|
|
660
|
+
pl.ds(bf_id * bf, bf)],
|
|
661
|
+
dst_ref=b_b3_x2_vmem.at[bf_id % 2],
|
|
662
|
+
sem=local_sems.at[bw3_sem_id, 3],
|
|
663
|
+
).start()
|
|
476
664
|
|
|
477
665
|
def wait_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id):
|
|
478
|
-
del local_e_id
|
|
666
|
+
del local_e_id
|
|
479
667
|
pltpu.make_async_copy(
|
|
480
668
|
src_ref=b_w1_x2_vmem.at[bw1_sem_id],
|
|
481
669
|
dst_ref=b_w1_x2_vmem.at[bw1_sem_id],
|
|
482
670
|
sem=local_sems.at[bw1_sem_id, 1],
|
|
483
671
|
).wait()
|
|
672
|
+
if w1_scale_hbm is not None:
|
|
673
|
+
pltpu.make_async_copy(
|
|
674
|
+
src_ref=b_w1_scale_x2_vmem.at[bw1_sem_id],
|
|
675
|
+
dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id],
|
|
676
|
+
sem=local_sems.at[bw1_sem_id, 1],
|
|
677
|
+
).wait()
|
|
678
|
+
if b1_hbm is not None and bd1_id == 0:
|
|
679
|
+
pltpu.make_async_copy(
|
|
680
|
+
src_ref=b_b1_x2_vmem.at[bf_id % 2],
|
|
681
|
+
dst_ref=b_b1_x2_vmem.at[bf_id % 2],
|
|
682
|
+
sem=local_sems.at[bw1_sem_id, 1],
|
|
683
|
+
).wait()
|
|
484
684
|
|
|
485
685
|
def wait_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id):
|
|
486
|
-
del local_e_id
|
|
686
|
+
del local_e_id
|
|
487
687
|
pltpu.make_async_copy(
|
|
488
688
|
src_ref=b_w2_x2_vmem.at[bw2_sem_id],
|
|
489
689
|
dst_ref=b_w2_x2_vmem.at[bw2_sem_id],
|
|
490
690
|
sem=local_sems.at[bw2_sem_id, 2],
|
|
491
691
|
).wait()
|
|
692
|
+
if w2_scale_hbm is not None:
|
|
693
|
+
pltpu.make_async_copy(
|
|
694
|
+
src_ref=b_w2_scale_x2_vmem.at[bw2_sem_id],
|
|
695
|
+
dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id],
|
|
696
|
+
sem=local_sems.at[bw2_sem_id, 2],
|
|
697
|
+
).wait()
|
|
698
|
+
if b2_hbm is not None and bf_id == 0:
|
|
699
|
+
pltpu.make_async_copy(
|
|
700
|
+
src_ref=b_b2_x2_vmem.at[bd2_id % 2],
|
|
701
|
+
dst_ref=b_b2_x2_vmem.at[bd2_id % 2],
|
|
702
|
+
sem=local_sems.at[bw2_sem_id, 2],
|
|
703
|
+
).wait()
|
|
492
704
|
|
|
493
705
|
def wait_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id):
|
|
494
|
-
del local_e_id
|
|
706
|
+
del local_e_id
|
|
495
707
|
pltpu.make_async_copy(
|
|
496
708
|
src_ref=b_w3_x2_vmem.at[bw3_sem_id],
|
|
497
709
|
dst_ref=b_w3_x2_vmem.at[bw3_sem_id],
|
|
498
710
|
sem=local_sems.at[bw3_sem_id, 3],
|
|
499
711
|
).wait()
|
|
712
|
+
if w1_scale_hbm is not None:
|
|
713
|
+
pltpu.make_async_copy(
|
|
714
|
+
src_ref=b_w3_scale_x2_vmem.at[bw3_sem_id],
|
|
715
|
+
dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id],
|
|
716
|
+
sem=local_sems.at[bw3_sem_id, 3],
|
|
717
|
+
).wait()
|
|
718
|
+
if b1_hbm is not None and bd3_id == 0:
|
|
719
|
+
pltpu.make_async_copy(
|
|
720
|
+
src_ref=b_b3_x2_vmem.at[bf_id % 2],
|
|
721
|
+
dst_ref=b_b3_x2_vmem.at[bf_id % 2],
|
|
722
|
+
sem=local_sems.at[bw3_sem_id, 3],
|
|
723
|
+
).wait()
|
|
500
724
|
|
|
501
725
|
def start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, bd2_id):
|
|
502
726
|
next_bd1_id = bd1_id + 1
|
|
@@ -520,18 +744,38 @@ def _fused_ep_moe_kernel(
|
|
|
520
744
|
def dynamic_ffn1(
|
|
521
745
|
t_b32_vmem,
|
|
522
746
|
w1_vmem,
|
|
747
|
+
w1_scale_vmem,
|
|
748
|
+
b1_vmem,
|
|
523
749
|
w3_vmem,
|
|
750
|
+
w3_scale_vmem,
|
|
751
|
+
b3_vmem,
|
|
524
752
|
acc1_vmem,
|
|
525
753
|
acc3_vmem,
|
|
526
754
|
dyn_sz,
|
|
527
755
|
should_init,
|
|
528
756
|
):
|
|
529
757
|
assert t_b32_vmem.shape == (bt * num_devices, bd1 // t_packing)
|
|
530
|
-
assert w1_vmem.shape == w3_vmem.shape == (t_packing,
|
|
758
|
+
assert w1_vmem.shape == w3_vmem.shape == (t_packing, bd1_per_t_packing,
|
|
531
759
|
bf)
|
|
532
760
|
assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf)
|
|
533
761
|
assert bd1 % (t_packing * 128) == 0, (bd1, t_packing)
|
|
534
762
|
assert bd1c % (t_packing * 128) == 0, (bd1c, t_packing)
|
|
763
|
+
if w1_scale_vmem is not None:
|
|
764
|
+
assert w1_scale_vmem.shape == (
|
|
765
|
+
t_packing,
|
|
766
|
+
bd1_per_t_packing // subc_quant_wsz,
|
|
767
|
+
1,
|
|
768
|
+
bf,
|
|
769
|
+
)
|
|
770
|
+
assert bd1c_per_t_packing == subc_quant_wsz
|
|
771
|
+
if w3_scale_vmem is not None:
|
|
772
|
+
assert w3_scale_vmem.shape == (
|
|
773
|
+
t_packing,
|
|
774
|
+
bd1_per_t_packing // subc_quant_wsz,
|
|
775
|
+
1,
|
|
776
|
+
bf,
|
|
777
|
+
)
|
|
778
|
+
assert bd1c_per_t_packing == subc_quant_wsz
|
|
535
779
|
|
|
536
780
|
num_loops = cdiv(dyn_sz, btc)
|
|
537
781
|
repack_ty = jnp.dtype(f"int{t_bitwidth}")
|
|
@@ -540,7 +784,7 @@ def _fused_ep_moe_kernel(
|
|
|
540
784
|
for bd1c_id in range(cdiv(bd1, bd1c)):
|
|
541
785
|
t_b32 = t_b32_vmem[
|
|
542
786
|
pl.ds(btc_id * btc, btc),
|
|
543
|
-
pl.ds(bd1c_id *
|
|
787
|
+
pl.ds(bd1c_id * bd1c_per_t_packing, bd1c_per_t_packing),
|
|
544
788
|
]
|
|
545
789
|
for p_id in range(t_packing):
|
|
546
790
|
t = pltpu.bitcast(t_b32.astype(repack_ty), t_dtype)
|
|
@@ -548,21 +792,64 @@ def _fused_ep_moe_kernel(
|
|
|
548
792
|
for bfc_id in range(cdiv(bf, bfc)):
|
|
549
793
|
w_slices = (
|
|
550
794
|
p_id,
|
|
551
|
-
pl.ds(bd1c_id *
|
|
552
|
-
|
|
795
|
+
pl.ds(bd1c_id * bd1c_per_t_packing,
|
|
796
|
+
bd1c_per_t_packing),
|
|
553
797
|
pl.ds(bfc_id * bfc, bfc),
|
|
554
798
|
)
|
|
555
799
|
w1 = w1_vmem[*w_slices]
|
|
556
800
|
acc1 = jnp.dot(t,
|
|
557
801
|
w1,
|
|
558
802
|
preferred_element_type=jnp.float32)
|
|
803
|
+
|
|
804
|
+
if w1_scale_vmem is not None:
|
|
805
|
+
w1_scale_slices = (
|
|
806
|
+
p_id,
|
|
807
|
+
bd1c_id,
|
|
808
|
+
pl.ds(0, 1),
|
|
809
|
+
pl.ds(bfc_id * bfc, bfc),
|
|
810
|
+
)
|
|
811
|
+
# TODO(jevinjiang): can use mosaic to load with stride 0.
|
|
812
|
+
w1_scale = jnp.broadcast_to(
|
|
813
|
+
w1_scale_vmem[*w1_scale_slices], acc1.shape)
|
|
814
|
+
acc1 *= w1_scale
|
|
815
|
+
|
|
559
816
|
w3 = w3_vmem[*w_slices]
|
|
817
|
+
|
|
560
818
|
acc3 = jnp.dot(t,
|
|
561
819
|
w3,
|
|
562
820
|
preferred_element_type=jnp.float32)
|
|
821
|
+
|
|
822
|
+
if w3_scale_vmem is not None:
|
|
823
|
+
w3_scale_slices = (
|
|
824
|
+
p_id,
|
|
825
|
+
bd1c_id,
|
|
826
|
+
pl.ds(0, 1),
|
|
827
|
+
pl.ds(bfc_id * bfc, bfc),
|
|
828
|
+
)
|
|
829
|
+
w3_scale = jnp.broadcast_to(
|
|
830
|
+
w3_scale_vmem[*w3_scale_slices], acc3.shape)
|
|
831
|
+
acc3 *= w3_scale
|
|
832
|
+
|
|
563
833
|
acc_slices = (pl.ds(btc_id * btc,
|
|
564
834
|
btc), pl.ds(bfc_id * bfc, bfc))
|
|
565
835
|
if should_init and p_id == bd1c_id == 0:
|
|
836
|
+
if b1_vmem is not None:
|
|
837
|
+
b1_scale_slices = (
|
|
838
|
+
pl.ds(0, 1),
|
|
839
|
+
pl.ds(bfc_id * bfc, bfc),
|
|
840
|
+
)
|
|
841
|
+
b1 = jnp.broadcast_to(
|
|
842
|
+
b1_vmem[*b1_scale_slices], acc1.shape)
|
|
843
|
+
acc1 += b1
|
|
844
|
+
if b3_vmem is not None:
|
|
845
|
+
b3_scale_slices = (
|
|
846
|
+
pl.ds(0, 1),
|
|
847
|
+
pl.ds(bfc_id * bfc, bfc),
|
|
848
|
+
)
|
|
849
|
+
b3 = jnp.broadcast_to(
|
|
850
|
+
b3_vmem[*b3_scale_slices], acc1.shape)
|
|
851
|
+
acc3 += b3
|
|
852
|
+
|
|
566
853
|
acc1_vmem[*acc_slices] = acc1
|
|
567
854
|
acc3_vmem[*acc_slices] = acc3
|
|
568
855
|
else:
|
|
@@ -575,22 +862,28 @@ def _fused_ep_moe_kernel(
|
|
|
575
862
|
acc1_vmem,
|
|
576
863
|
acc3_vmem,
|
|
577
864
|
w2_vmem,
|
|
865
|
+
w2_scale_vmem,
|
|
866
|
+
b2_vmem,
|
|
578
867
|
res_b32_vmem,
|
|
579
868
|
dyn_sz,
|
|
580
869
|
should_init,
|
|
581
870
|
):
|
|
582
|
-
assert res_b32_vmem.shape == (bt * num_devices,
|
|
583
|
-
assert w2_vmem.shape == (t_packing, bf,
|
|
584
|
-
w2_vmem.shape,
|
|
585
|
-
t_packing,
|
|
586
|
-
bf,
|
|
587
|
-
bd2_per_packing,
|
|
588
|
-
)
|
|
871
|
+
assert res_b32_vmem.shape == (bt * num_devices, bd2_per_t_packing)
|
|
872
|
+
assert w2_vmem.shape == (t_packing, bf, bd2_per_t_packing)
|
|
589
873
|
assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf)
|
|
590
874
|
assert bd2 % (t_packing * 128) == 0, (bd2, t_packing)
|
|
591
875
|
assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing)
|
|
592
876
|
assert t_dtype in (jnp.float32, jnp.bfloat16)
|
|
593
877
|
|
|
878
|
+
if w2_scale_vmem is not None:
|
|
879
|
+
assert w2_scale_vmem.shape == (
|
|
880
|
+
t_packing,
|
|
881
|
+
bf // subc_quant_wsz,
|
|
882
|
+
1,
|
|
883
|
+
bd2_per_t_packing,
|
|
884
|
+
)
|
|
885
|
+
assert bfc == subc_quant_wsz
|
|
886
|
+
|
|
594
887
|
num_loops = cdiv(dyn_sz, btc)
|
|
595
888
|
assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing)
|
|
596
889
|
|
|
@@ -598,22 +891,47 @@ def _fused_ep_moe_kernel(
|
|
|
598
891
|
for bd2c_id in range(cdiv(bd2, bd2c)):
|
|
599
892
|
res_lst = []
|
|
600
893
|
for p_id in range(t_packing):
|
|
601
|
-
res = jnp.zeros((btc,
|
|
894
|
+
res = jnp.zeros((btc, bd2c_per_t_packing),
|
|
895
|
+
dtype=jnp.float32)
|
|
896
|
+
|
|
897
|
+
if b2_vmem is not None and should_init:
|
|
898
|
+
b2_scale_slices = (
|
|
899
|
+
p_id,
|
|
900
|
+
pl.ds(0, 1),
|
|
901
|
+
pl.ds(bd2c_id * bd2c_per_t_packing,
|
|
902
|
+
bd2c_per_t_packing),
|
|
903
|
+
)
|
|
904
|
+
b2 = jnp.broadcast_to(b2_vmem[*b2_scale_slices],
|
|
905
|
+
res.shape)
|
|
906
|
+
res += b2
|
|
907
|
+
|
|
602
908
|
for bfc_id in range(cdiv(bf, bfc)):
|
|
603
909
|
acc_slices = (pl.ds(btc_id * btc,
|
|
604
910
|
btc), pl.ds(bfc_id * bfc, bfc))
|
|
605
911
|
acc1 = acc1_vmem[*acc_slices]
|
|
606
912
|
acc3 = acc3_vmem[*acc_slices]
|
|
607
|
-
act =
|
|
913
|
+
act = activation_fn(acc1, acc3, act_fn)
|
|
608
914
|
w2 = w2_vmem[
|
|
609
915
|
p_id,
|
|
610
916
|
pl.ds(bfc_id * bfc, bfc),
|
|
611
917
|
pl.ds(bd2c_id *
|
|
612
|
-
|
|
918
|
+
bd2c_per_t_packing, bd2c_per_t_packing),
|
|
613
919
|
]
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
920
|
+
acc = jnp.dot(act,
|
|
921
|
+
w2,
|
|
922
|
+
preferred_element_type=jnp.float32)
|
|
923
|
+
if w2_scale_vmem is not None:
|
|
924
|
+
w2_scale_slices = (
|
|
925
|
+
p_id,
|
|
926
|
+
bfc_id,
|
|
927
|
+
pl.ds(0, 1),
|
|
928
|
+
pl.ds(bd2c_id * bd2c_per_t_packing,
|
|
929
|
+
bd2c_per_t_packing),
|
|
930
|
+
)
|
|
931
|
+
w2_scale = jnp.broadcast_to(
|
|
932
|
+
w2_scale_vmem[*w2_scale_slices], acc.shape)
|
|
933
|
+
acc *= w2_scale
|
|
934
|
+
res += acc
|
|
617
935
|
res = pltpu.bitcast(res, jnp.uint32)
|
|
618
936
|
if t_packing == 2:
|
|
619
937
|
res = res >> 16 << (16 * p_id)
|
|
@@ -626,7 +944,7 @@ def _fused_ep_moe_kernel(
|
|
|
626
944
|
res |= res_lst[i]
|
|
627
945
|
sliced_res_vmem = res_b32_vmem.at[
|
|
628
946
|
pl.ds(btc_id * btc, btc),
|
|
629
|
-
pl.ds(bd2c_id *
|
|
947
|
+
pl.ds(bd2c_id * bd2c_per_t_packing, bd2c_per_t_packing),
|
|
630
948
|
]
|
|
631
949
|
if should_init:
|
|
632
950
|
sliced_res_vmem[...] = res
|
|
@@ -655,21 +973,33 @@ def _fused_ep_moe_kernel(
|
|
|
655
973
|
e_id = my_id * local_num_experts + local_e_id
|
|
656
974
|
dyn_sz = expert_sizes_x2_smem[bt_sem_id, 0, e_id]
|
|
657
975
|
|
|
658
|
-
|
|
659
|
-
|
|
976
|
+
bd1_per_t_packing = bd1 // t_packing
|
|
977
|
+
bd2_per_t_packing = bd2 // t_packing
|
|
660
978
|
|
|
661
979
|
for bf_id in range(num_bf):
|
|
662
980
|
for bd1_id in range(num_bd1):
|
|
663
981
|
start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, 0)
|
|
982
|
+
w1_scale_vmem = (None if b_w1_scale_x2_vmem is None else
|
|
983
|
+
b_w1_scale_x2_vmem.at[bw_sem_id])
|
|
984
|
+
w3_scale_vmem = (None if b_w3_scale_x2_vmem is None else
|
|
985
|
+
b_w3_scale_x2_vmem.at[bw_sem_id])
|
|
986
|
+
b1_vmem = None if b_b1_x2_vmem is None else b_b1_x2_vmem.at[
|
|
987
|
+
bf_id % 2]
|
|
988
|
+
b3_vmem = None if b_b3_x2_vmem is None else b_b3_x2_vmem.at[
|
|
989
|
+
bf_id % 2]
|
|
664
990
|
wait_fetch_bw1(local_e_id, bw_sem_id, bf_id, bd1_id)
|
|
665
991
|
wait_fetch_bw3(local_e_id, bw_sem_id, bf_id, bd1_id)
|
|
666
992
|
|
|
667
993
|
dynamic_ffn1(
|
|
668
994
|
t_b32_vmem=a2a_s_b32_vmem.at[
|
|
669
995
|
...,
|
|
670
|
-
pl.ds(bd1_id *
|
|
996
|
+
pl.ds(bd1_id * bd1_per_t_packing, bd1_per_t_packing)],
|
|
671
997
|
w1_vmem=b_w1_x2_vmem.at[bw_sem_id],
|
|
998
|
+
w1_scale_vmem=w1_scale_vmem,
|
|
999
|
+
b1_vmem=b1_vmem,
|
|
672
1000
|
w3_vmem=b_w3_x2_vmem.at[bw_sem_id],
|
|
1001
|
+
w3_scale_vmem=w3_scale_vmem,
|
|
1002
|
+
b3_vmem=b3_vmem,
|
|
673
1003
|
acc1_vmem=b_acc1_vmem,
|
|
674
1004
|
acc3_vmem=b_acc3_vmem,
|
|
675
1005
|
dyn_sz=dyn_sz,
|
|
@@ -684,13 +1014,19 @@ def _fused_ep_moe_kernel(
|
|
|
684
1014
|
if bf_id == bd2_id == 0:
|
|
685
1015
|
wait_a2a_gather_send(bt_id, e_sem_id, local_e_id - 2)
|
|
686
1016
|
|
|
1017
|
+
w2_scale_vmem = (None if b_w2_scale_x2_vmem is None else
|
|
1018
|
+
b_w2_scale_x2_vmem.at[bw_sem_id])
|
|
1019
|
+
b2_vmem = None if b_b2_x2_vmem is None else b_b2_x2_vmem.at[
|
|
1020
|
+
bd2_id % 2]
|
|
687
1021
|
dynamic_ffn2(
|
|
688
1022
|
acc1_vmem=b_acc1_vmem,
|
|
689
1023
|
acc3_vmem=b_acc3_vmem,
|
|
690
1024
|
w2_vmem=b_w2_x2_vmem.at[bw_sem_id],
|
|
1025
|
+
w2_scale_vmem=w2_scale_vmem,
|
|
1026
|
+
b2_vmem=b2_vmem,
|
|
691
1027
|
res_b32_vmem=a2a_s_acc_b32_vmem.at[
|
|
692
1028
|
...,
|
|
693
|
-
pl.ds(bd2_id *
|
|
1029
|
+
pl.ds(bd2_id * bd2_per_t_packing, bd2_per_t_packing)],
|
|
694
1030
|
dyn_sz=dyn_sz,
|
|
695
1031
|
should_init=(bf_id == 0),
|
|
696
1032
|
)
|
|
@@ -757,31 +1093,42 @@ def _fused_ep_moe_kernel(
|
|
|
757
1093
|
b_gating = b_gating_x2_vmem[bt_sem_id]
|
|
758
1094
|
b_gating_score = jax.nn.softmax(b_gating, axis=-1)
|
|
759
1095
|
top_k_logits_lst, t2e_routing, expert_sizes, expert_starts = get_top_k(
|
|
760
|
-
b_gating_score, top_k)
|
|
1096
|
+
b_gating_score, top_k, renormalize_topk_logits)
|
|
761
1097
|
|
|
762
1098
|
all_reduce_metadata(bt_sem_id, t2e_routing, expert_starts,
|
|
763
1099
|
expert_sizes)
|
|
1100
|
+
sync_barrier()
|
|
764
1101
|
|
|
1102
|
+
# Start a2a scatter for first active expert.
|
|
765
1103
|
start_a2a_scatter(bt_id=bt_id, e_sem_id=e_sem_id, local_e_id=0)
|
|
766
1104
|
|
|
767
1105
|
def run_per_expert(local_e_id, e_sem_id):
|
|
768
1106
|
sync_barrier()
|
|
1107
|
+
|
|
1108
|
+
# Prefetch weights for CURRENT active expert.
|
|
1109
|
+
# TODO(jevinjiang): It is hard to prefetch weights in previous iteration
|
|
1110
|
+
# because the expert_ffn keeps overwriting the buffers. Triple buffering
|
|
1111
|
+
# could resolve this but it takes more VMEM scratch. Need further
|
|
1112
|
+
# experiment on this.
|
|
1113
|
+
start_fetch_bw1(local_e_id, bw1_sem_id=0, bf_id=0, bd1_id=0)
|
|
1114
|
+
start_fetch_bw3(local_e_id, bw3_sem_id=0, bf_id=0, bd3_id=0)
|
|
1115
|
+
|
|
1116
|
+
# Next ids.
|
|
769
1117
|
next_e_sem_id = lax.select(e_sem_id == 0, 1, 0)
|
|
770
1118
|
next_local_e_id = local_e_id + 1
|
|
771
1119
|
|
|
1120
|
+
# Start a2a scatter for NEXT active expert.
|
|
772
1121
|
@pl.when(next_local_e_id < local_num_experts)
|
|
773
1122
|
def _():
|
|
774
1123
|
start_a2a_scatter(bt_id, next_e_sem_id, next_local_e_id)
|
|
775
1124
|
|
|
776
|
-
#
|
|
777
|
-
start_fetch_bw1(local_e_id, bw1_sem_id=0, bf_id=0, bd1_id=0)
|
|
778
|
-
start_fetch_bw3(local_e_id, bw3_sem_id=0, bf_id=0, bd3_id=0)
|
|
779
|
-
|
|
780
|
-
# Wait for a2a scatter and perform FFN for active expert.
|
|
1125
|
+
# Wait a2a scatter for CURRENT active expert.
|
|
781
1126
|
wait_a2a_scatter_recv(bt_id, e_sem_id, local_e_id)
|
|
1127
|
+
|
|
1128
|
+
# Perform FFN for CURRENT active expert.
|
|
782
1129
|
expert_ffn(bt_id, e_sem_id, local_e_id)
|
|
783
1130
|
|
|
784
|
-
#
|
|
1131
|
+
# Start a2a gather to send back tokens for CURRENT active expert.
|
|
785
1132
|
start_a2a_gather(bt_id, e_sem_id, local_e_id)
|
|
786
1133
|
|
|
787
1134
|
# A must-wait before next sync_barrier.
|
|
@@ -794,7 +1141,10 @@ def _fused_ep_moe_kernel(
|
|
|
794
1141
|
e_sem_id,
|
|
795
1142
|
unroll=False)
|
|
796
1143
|
|
|
1144
|
+
# Wait to receive a2a gather for ALL experts.
|
|
797
1145
|
wait_a2a_gather_recv_all()
|
|
1146
|
+
|
|
1147
|
+
# Accumulate results for current batch.
|
|
798
1148
|
output = bt_acc(bt_id, top_k_logits_lst)
|
|
799
1149
|
|
|
800
1150
|
# Make sure it is safe to overwrite output buffer.
|
|
@@ -827,6 +1177,9 @@ def _fused_ep_moe_kernel(
|
|
|
827
1177
|
static_argnames=[
|
|
828
1178
|
"mesh",
|
|
829
1179
|
"top_k",
|
|
1180
|
+
"renormalize_topk_logits",
|
|
1181
|
+
"act_fn",
|
|
1182
|
+
"subc_quant_wsz",
|
|
830
1183
|
"bt",
|
|
831
1184
|
"bf",
|
|
832
1185
|
"bd1",
|
|
@@ -846,6 +1199,17 @@ def fused_ep_moe(
|
|
|
846
1199
|
gating_output: jax.Array, # (num_tokens, num_experts)
|
|
847
1200
|
top_k: int,
|
|
848
1201
|
*,
|
|
1202
|
+
renormalize_topk_logits: bool = False,
|
|
1203
|
+
act_fn: str = "silu",
|
|
1204
|
+
subc_quant_wsz: int | None = None,
|
|
1205
|
+
w1_scale: (
|
|
1206
|
+
jax.Array | None
|
|
1207
|
+
) = None, # F32(num_experts, 2, hidden_size // subc_quant_wsz, 1, intermediate_size)
|
|
1208
|
+
w2_scale: (
|
|
1209
|
+
jax.Array | None
|
|
1210
|
+
) = None, # F32(num_experts, intermediate_size // subc_quant_wsz, 1, hidden_size)
|
|
1211
|
+
b1: jax.Array | None = None, # F32(num_experts, 2, 1, intermediate_size)
|
|
1212
|
+
b2: jax.Array | None = None, # F32(num_experts, 1, hidden_size)
|
|
849
1213
|
# Kernel tuning parameters.
|
|
850
1214
|
bt: int,
|
|
851
1215
|
bf: int,
|
|
@@ -855,52 +1219,164 @@ def fused_ep_moe(
|
|
|
855
1219
|
bfc: int,
|
|
856
1220
|
bd1c: int,
|
|
857
1221
|
bd2c: int,
|
|
858
|
-
ep_axis_name: str =
|
|
1222
|
+
ep_axis_name: str = "model",
|
|
859
1223
|
):
|
|
860
|
-
#
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
|
|
1224
|
+
# TODO(jevinjiang): move all these assertions to validation function.
|
|
1225
|
+
if len(mesh.shape) != 2:
|
|
1226
|
+
raise NotImplementedError("Only 2D mesh is supported.")
|
|
1227
|
+
|
|
1228
|
+
for axis_name in mesh.axis_names:
|
|
1229
|
+
if axis_name == ep_axis_name:
|
|
1230
|
+
continue
|
|
1231
|
+
if mesh.shape[axis_name] != 1:
|
|
1232
|
+
raise NotImplementedError(
|
|
1233
|
+
f"Expected all non-ep axis to have size 1 in {mesh.shape=}")
|
|
864
1234
|
|
|
865
1235
|
ep_size = mesh.shape[ep_axis_name]
|
|
866
1236
|
num_devices = ep_size
|
|
867
1237
|
|
|
868
|
-
num_tokens,
|
|
1238
|
+
num_tokens, hidden_size = tokens.shape
|
|
869
1239
|
num_experts, intermediate_size, _ = w2.shape
|
|
870
1240
|
|
|
871
|
-
|
|
872
|
-
|
|
1241
|
+
if w1.shape != (num_experts, 2, hidden_size, intermediate_size):
|
|
1242
|
+
raise ValueError(
|
|
1243
|
+
f"Expected {w1.shape=} to be"
|
|
1244
|
+
f" {(num_experts, 2, hidden_size, intermediate_size)}.")
|
|
1245
|
+
|
|
1246
|
+
if w2.shape != (num_experts, intermediate_size, hidden_size):
|
|
1247
|
+
raise ValueError(f"Expected {w2.shape=} to be"
|
|
1248
|
+
f" {(num_experts, intermediate_size, hidden_size)}.")
|
|
1249
|
+
|
|
1250
|
+
if gating_output.shape != (num_tokens, num_experts):
|
|
1251
|
+
raise ValueError(
|
|
1252
|
+
f"Expected {gating_output.shape=} to be {(num_tokens, num_experts)}."
|
|
1253
|
+
)
|
|
1254
|
+
|
|
1255
|
+
if not (0 < top_k <= num_experts):
|
|
1256
|
+
raise ValueError(
|
|
1257
|
+
f"Expected {top_k=} to be in range (0, {num_experts=}].")
|
|
1258
|
+
|
|
1259
|
+
if hidden_size % 128 != 0 or intermediate_size % 128 != 0:
|
|
1260
|
+
raise ValueError(
|
|
1261
|
+
f"Expected {hidden_size=} and {intermediate_size=} to be aligned to"
|
|
1262
|
+
" 128. Did you pad them with zeros outside the kernel?")
|
|
1263
|
+
if num_tokens % ep_size != 0:
|
|
1264
|
+
raise ValueError(
|
|
1265
|
+
f"Expected {num_tokens=} to be aligned to {ep_size=}.")
|
|
1266
|
+
if num_experts % ep_size != 0:
|
|
1267
|
+
raise ValueError(
|
|
1268
|
+
f"Expected {num_experts=} to be aligned to {ep_size=}.")
|
|
873
1269
|
|
|
874
1270
|
local_num_tokens = num_tokens // ep_size
|
|
875
1271
|
# local_num_experts = num_experts // ep_size
|
|
876
1272
|
padded_num_experts = align_to(num_experts, 128)
|
|
877
|
-
|
|
1273
|
+
padded_top_k = align_to(top_k, 128)
|
|
878
1274
|
t_dtype = tokens.dtype
|
|
879
1275
|
t_packing = get_dtype_packing(t_dtype)
|
|
880
|
-
hidden_size = align_to(actual_hidden_size, 128 * t_packing)
|
|
881
|
-
if hidden_size != actual_hidden_size:
|
|
882
|
-
tokens = jnp.pad(
|
|
883
|
-
tokens,
|
|
884
|
-
((0, 0), (0, hidden_size - actual_hidden_size)),
|
|
885
|
-
constant_values=0,
|
|
886
|
-
)
|
|
887
|
-
tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing)
|
|
888
|
-
bt = min(bt, local_num_tokens)
|
|
889
|
-
bf = min(bf, intermediate_size)
|
|
890
|
-
bd1 = min(bd1, hidden_size)
|
|
891
|
-
bd2 = min(bd2, hidden_size)
|
|
892
|
-
|
|
893
|
-
btc = min(btc, bt * num_devices)
|
|
894
|
-
bfc = min(bfc, bf)
|
|
895
|
-
bd1c = min(bd1c, bd1)
|
|
896
|
-
bd2c = min(bd2c, bd2)
|
|
897
|
-
assert bfc % 128 == 0
|
|
898
|
-
assert bd1c % (t_packing * 128) == 0
|
|
899
|
-
assert bd2c % (t_packing * 128) == 0
|
|
900
|
-
assert bf % bfc == 0
|
|
901
|
-
assert bd1 % bd1c == 0
|
|
902
|
-
assert bd2 % bd2c == 0
|
|
903
1276
|
|
|
1277
|
+
# Override bt
|
|
1278
|
+
if local_num_tokens <= t_packing * 8:
|
|
1279
|
+
bt = local_num_tokens
|
|
1280
|
+
btc = bt
|
|
1281
|
+
bt = min(local_num_tokens, bt)
|
|
1282
|
+
# The worst case is that all devices send bt to one device.
|
|
1283
|
+
btc = min(bt, btc, bt * num_devices)
|
|
1284
|
+
|
|
1285
|
+
if local_num_tokens % t_packing != 0:
|
|
1286
|
+
raise ValueError(
|
|
1287
|
+
f"Expected {local_num_tokens=} to be aligned to {t_packing=}.")
|
|
1288
|
+
|
|
1289
|
+
if bt % t_packing != 0:
|
|
1290
|
+
raise ValueError(f"Expected {bt=} to be aligned to {t_packing=}.")
|
|
1291
|
+
if local_num_tokens % bt != 0:
|
|
1292
|
+
raise ValueError(
|
|
1293
|
+
f"Expected {local_num_tokens=} to be aligned to {bt=}.")
|
|
1294
|
+
|
|
1295
|
+
if subc_quant_wsz is not None:
|
|
1296
|
+
if subc_quant_wsz <= 0:
|
|
1297
|
+
raise ValueError(f"Expected {subc_quant_wsz=} to be non-negative.")
|
|
1298
|
+
if subc_quant_wsz % 256 != 0:
|
|
1299
|
+
raise ValueError(
|
|
1300
|
+
"Expected {subc_quant_wsz=} to be aligned to 256.")
|
|
1301
|
+
if hidden_size % subc_quant_wsz != 0:
|
|
1302
|
+
raise ValueError(
|
|
1303
|
+
f"Expected {hidden_size=} to be aligned to {subc_quant_wsz=}.")
|
|
1304
|
+
if intermediate_size % subc_quant_wsz != 0:
|
|
1305
|
+
raise ValueError(
|
|
1306
|
+
f"Expected {intermediate_size=} to be aligned to {subc_quant_wsz=}."
|
|
1307
|
+
)
|
|
1308
|
+
# We force compute size of contracting dim to be subc_quant_wsz. So we can
|
|
1309
|
+
# apply same scale after matmul and accumulation.
|
|
1310
|
+
bd1c = subc_quant_wsz * t_packing
|
|
1311
|
+
bfc = subc_quant_wsz
|
|
1312
|
+
|
|
1313
|
+
if bfc % 128 != 0:
|
|
1314
|
+
raise ValueError(f"Expected {bfc=} to be aligned to 128.")
|
|
1315
|
+
if bd1c % (t_packing * 128) != 0:
|
|
1316
|
+
raise ValueError(
|
|
1317
|
+
f"Expected {bd1c=} to be aligned to {t_packing * 128}.")
|
|
1318
|
+
if bd2c % (t_packing * 128) != 0:
|
|
1319
|
+
raise ValueError(
|
|
1320
|
+
f"Expected {bd2c=} to be aligned to {t_packing * 128}.")
|
|
1321
|
+
if bf % bfc != 0:
|
|
1322
|
+
raise ValueError(f"Expected {bf=} to be aligned to {bfc=}.")
|
|
1323
|
+
if bd1 % bd1c != 0:
|
|
1324
|
+
raise ValueError(f"Expected {bd1=} to be aligned to {bd1c=}.")
|
|
1325
|
+
if bd2 % bd2c != 0:
|
|
1326
|
+
raise ValueError(f"Expected {bd2=} to be aligned to {bd2c=}.")
|
|
1327
|
+
if hidden_size % bd1 != 0 or hidden_size % bd2 != 0:
|
|
1328
|
+
raise ValueError(
|
|
1329
|
+
f"Expected {hidden_size=} to be aligned to {bd1=} and {bd2=}.")
|
|
1330
|
+
if intermediate_size % bf != 0:
|
|
1331
|
+
raise ValueError(
|
|
1332
|
+
f"Expected {intermediate_size=} to be aligned to {bf=}.")
|
|
1333
|
+
|
|
1334
|
+
# Note: we should dump scale as the kernel expected shape in the
|
|
1335
|
+
# checkpoint offline or reshape right after weight loading.
|
|
1336
|
+
if w1_scale is not None:
|
|
1337
|
+
expected_w1_scale_shape = (
|
|
1338
|
+
num_experts,
|
|
1339
|
+
2,
|
|
1340
|
+
hidden_size // subc_quant_wsz,
|
|
1341
|
+
1,
|
|
1342
|
+
intermediate_size,
|
|
1343
|
+
)
|
|
1344
|
+
if w1_scale.shape != expected_w1_scale_shape:
|
|
1345
|
+
raise ValueError(
|
|
1346
|
+
f"Expected {w1_scale.shape=} to be {expected_w1_scale_shape}.")
|
|
1347
|
+
if w1_scale.dtype != jnp.float32:
|
|
1348
|
+
w1_scale = w1_scale.astype(jnp.float32)
|
|
1349
|
+
|
|
1350
|
+
if w2_scale is not None:
|
|
1351
|
+
expected_w2_scale_shape = (
|
|
1352
|
+
num_experts,
|
|
1353
|
+
intermediate_size // subc_quant_wsz,
|
|
1354
|
+
1,
|
|
1355
|
+
hidden_size,
|
|
1356
|
+
)
|
|
1357
|
+
if w2_scale.shape != expected_w2_scale_shape:
|
|
1358
|
+
raise ValueError(
|
|
1359
|
+
f"Expected {w2_scale.shape=} to be {expected_w2_scale_shape}.")
|
|
1360
|
+
if w2_scale.dtype != jnp.float32:
|
|
1361
|
+
w2_scale = w2_scale.astype(jnp.float32)
|
|
1362
|
+
|
|
1363
|
+
if b1 is not None:
|
|
1364
|
+
expected_b1_shape = (num_experts, 2, 1, intermediate_size)
|
|
1365
|
+
if b1.shape != expected_b1_shape:
|
|
1366
|
+
raise ValueError(
|
|
1367
|
+
f"Expected {b1.shape=} to be {expected_b1_shape}.")
|
|
1368
|
+
if b1.dtype != jnp.float32:
|
|
1369
|
+
b1 = b1.astype(jnp.float32)
|
|
1370
|
+
|
|
1371
|
+
if b2 is not None:
|
|
1372
|
+
expected_b2_shape = (num_experts, 1, hidden_size)
|
|
1373
|
+
if b2.shape != expected_b2_shape:
|
|
1374
|
+
raise ValueError(
|
|
1375
|
+
f"Expected {b2.shape=} to be {expected_b2_shape}.")
|
|
1376
|
+
if b2.dtype != jnp.float32:
|
|
1377
|
+
b2 = b2.astype(jnp.float32)
|
|
1378
|
+
|
|
1379
|
+
# Prepare inputs for the kernel.
|
|
904
1380
|
if padded_num_experts != gating_output.shape[-1]:
|
|
905
1381
|
gating_output = jnp.pad(
|
|
906
1382
|
gating_output,
|
|
@@ -908,128 +1384,229 @@ def fused_ep_moe(
|
|
|
908
1384
|
constant_values=-jnp.inf,
|
|
909
1385
|
)
|
|
910
1386
|
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
1387
|
+
tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing)
|
|
1388
|
+
|
|
1389
|
+
hbm_block_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM)
|
|
1390
|
+
renorm_str = "-renorm_k" if renormalize_topk_logits else ""
|
|
1391
|
+
scope_name = f"fused-moe-k_{top_k}{renorm_str}-bt_{bt}_{btc}-bf_{bf}_{bfc}-bd1_{bd1}_{bd1c}-bd2_{bd2}_{bd2c}"
|
|
1392
|
+
fused_moe = pl.pallas_call(
|
|
1393
|
+
functools.partial(
|
|
1394
|
+
_fused_ep_moe_kernel,
|
|
1395
|
+
top_k=top_k,
|
|
1396
|
+
renormalize_topk_logits=renormalize_topk_logits,
|
|
1397
|
+
ep_axis_name=ep_axis_name,
|
|
1398
|
+
act_fn=act_fn,
|
|
1399
|
+
subc_quant_wsz=subc_quant_wsz,
|
|
1400
|
+
bt=bt,
|
|
1401
|
+
bf=bf,
|
|
1402
|
+
bd1=bd1,
|
|
1403
|
+
bd2=bd2,
|
|
1404
|
+
btc=btc,
|
|
1405
|
+
bfc=bfc,
|
|
1406
|
+
bd1c=bd1c,
|
|
1407
|
+
bd2c=bd2c,
|
|
1408
|
+
),
|
|
1409
|
+
out_shape=jax.ShapeDtypeStruct((local_num_tokens, hidden_size),
|
|
1410
|
+
t_dtype),
|
|
1411
|
+
grid_spec=pltpu.PrefetchScalarGridSpec(
|
|
1412
|
+
num_scalar_prefetch=0,
|
|
1413
|
+
in_specs=[
|
|
1414
|
+
hbm_block_spec, # tokens_hbm
|
|
1415
|
+
hbm_block_spec, # w1_hbm
|
|
1416
|
+
hbm_block_spec, # w2_hbm
|
|
1417
|
+
None if w1_scale is None else hbm_block_spec, # w1_scale_hbm
|
|
1418
|
+
None if w2_scale is None else hbm_block_spec, # w2_scale_hbm
|
|
1419
|
+
None if b1 is None else hbm_block_spec, # b1_hbm
|
|
1420
|
+
None if b2 is None else hbm_block_spec, # b2_hbm
|
|
1421
|
+
hbm_block_spec, # gating_output_hbm
|
|
1422
|
+
hbm_block_spec, # a2a_g_hbm
|
|
1423
|
+
],
|
|
1424
|
+
out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
|
|
1425
|
+
scratch_shapes=([
|
|
1426
|
+
# t2e_routing_x2_smem
|
|
1427
|
+
pltpu.SMEM((2, bt, padded_top_k), jnp.int32),
|
|
1428
|
+
# d2e_count_x2_smem
|
|
1429
|
+
pltpu.SMEM((2, num_devices, 1, padded_num_experts), jnp.int32),
|
|
1430
|
+
# expert_offsets_x2_smem
|
|
1431
|
+
pltpu.SMEM((2, 2, padded_num_experts), jnp.int32),
|
|
1432
|
+
# expert_starts_x2_smem
|
|
1433
|
+
pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
|
|
1434
|
+
# expert_sizes_x2_smem
|
|
1435
|
+
pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
|
|
1436
|
+
# a2a_s_sends_x2_smem
|
|
1437
|
+
pltpu.SMEM((2, ), jnp.int32),
|
|
1438
|
+
# a2a_s_x2_vmem
|
|
1439
|
+
pltpu.VMEM(
|
|
1440
|
+
(
|
|
1441
|
+
2,
|
|
1442
|
+
bt * num_devices,
|
|
1443
|
+
t_packing,
|
|
1444
|
+
hidden_size // t_packing,
|
|
962
1445
|
),
|
|
963
|
-
|
|
964
|
-
|
|
965
|
-
|
|
966
|
-
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
971
|
-
|
|
1446
|
+
t_dtype,
|
|
1447
|
+
),
|
|
1448
|
+
# a2a_s_acc_x2_vmem
|
|
1449
|
+
pltpu.VMEM(
|
|
1450
|
+
(
|
|
1451
|
+
2,
|
|
1452
|
+
bt * num_devices,
|
|
1453
|
+
t_packing,
|
|
1454
|
+
hidden_size // t_packing,
|
|
972
1455
|
),
|
|
973
|
-
|
|
974
|
-
|
|
975
|
-
|
|
976
|
-
|
|
977
|
-
|
|
978
|
-
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
|
|
987
|
-
|
|
988
|
-
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
|
|
992
|
-
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
996
|
-
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
|
|
1000
|
-
|
|
1001
|
-
|
|
1002
|
-
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
|
-
|
|
1456
|
+
t_dtype,
|
|
1457
|
+
),
|
|
1458
|
+
# a2a_g_acc_vmem
|
|
1459
|
+
pltpu.VMEM((top_k, bt, t_packing, hidden_size // t_packing),
|
|
1460
|
+
t_dtype),
|
|
1461
|
+
# b_gating_x2_vmem
|
|
1462
|
+
pltpu.VMEM((2, bt, padded_num_experts), t_dtype),
|
|
1463
|
+
# b_output_x2_vmem
|
|
1464
|
+
pltpu.VMEM((2, bt, hidden_size), t_dtype),
|
|
1465
|
+
# b_w1_x2_vmem
|
|
1466
|
+
pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
|
|
1467
|
+
# b_w3_x2_vmem
|
|
1468
|
+
pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
|
|
1469
|
+
# b_w2_x2_vmem
|
|
1470
|
+
pltpu.VMEM((2, t_packing, bf, bd2 // t_packing), w2.dtype),
|
|
1471
|
+
# b_w1_scale_x2_vmem
|
|
1472
|
+
(None if w1_scale is None else pltpu.VMEM(
|
|
1473
|
+
(
|
|
1474
|
+
2,
|
|
1475
|
+
t_packing,
|
|
1476
|
+
bd1 // t_packing // subc_quant_wsz,
|
|
1477
|
+
1,
|
|
1478
|
+
bf,
|
|
1479
|
+
),
|
|
1480
|
+
jnp.float32,
|
|
1481
|
+
)),
|
|
1482
|
+
# b_w3_scale_x2_vmem
|
|
1483
|
+
(None if w1_scale is None else pltpu.VMEM(
|
|
1484
|
+
(
|
|
1485
|
+
2,
|
|
1486
|
+
t_packing,
|
|
1487
|
+
bd1 // t_packing // subc_quant_wsz,
|
|
1488
|
+
1,
|
|
1489
|
+
bf,
|
|
1490
|
+
),
|
|
1491
|
+
jnp.float32,
|
|
1492
|
+
)),
|
|
1493
|
+
# b_w2_scale_x2_vmem
|
|
1494
|
+
(None if w2_scale is None else pltpu.VMEM(
|
|
1495
|
+
(
|
|
1496
|
+
2,
|
|
1497
|
+
t_packing,
|
|
1498
|
+
bf // subc_quant_wsz,
|
|
1499
|
+
1,
|
|
1500
|
+
bd2 // t_packing,
|
|
1501
|
+
),
|
|
1502
|
+
jnp.float32,
|
|
1503
|
+
)),
|
|
1504
|
+
# b_b1_x2_vmem
|
|
1505
|
+
(None if b1 is None else pltpu.VMEM(
|
|
1506
|
+
(
|
|
1507
|
+
2,
|
|
1508
|
+
1,
|
|
1509
|
+
bf,
|
|
1510
|
+
),
|
|
1511
|
+
jnp.float32,
|
|
1512
|
+
)),
|
|
1513
|
+
# b_b3_x2_vmem
|
|
1514
|
+
(None if b1 is None else pltpu.VMEM(
|
|
1515
|
+
(
|
|
1516
|
+
2,
|
|
1517
|
+
1,
|
|
1518
|
+
bf,
|
|
1519
|
+
),
|
|
1520
|
+
jnp.float32,
|
|
1521
|
+
)),
|
|
1522
|
+
# b_b2_x2_vmem
|
|
1523
|
+
(None if b2 is None else pltpu.VMEM(
|
|
1524
|
+
(
|
|
1525
|
+
2,
|
|
1526
|
+
t_packing,
|
|
1527
|
+
1,
|
|
1528
|
+
bd2 // t_packing,
|
|
1529
|
+
),
|
|
1530
|
+
jnp.float32,
|
|
1531
|
+
)),
|
|
1532
|
+
# b_acc_vmem
|
|
1533
|
+
pltpu.VMEM((bt * num_devices, 1, bf * 2), jnp.float32),
|
|
1534
|
+
# local_sems
|
|
1535
|
+
pltpu.SemaphoreType.DMA((2, 5)),
|
|
1536
|
+
# send_sems
|
|
1537
|
+
pltpu.SemaphoreType.DMA((2, )),
|
|
1538
|
+
# recv_sems
|
|
1539
|
+
pltpu.SemaphoreType.DMA((2, )),
|
|
1540
|
+
# a2a_gather_sem
|
|
1541
|
+
pltpu.SemaphoreType.DMA,
|
|
1542
|
+
# a2a_acc_sem
|
|
1543
|
+
pltpu.SemaphoreType.DMA,
|
|
1544
|
+
]),
|
|
1545
|
+
),
|
|
1546
|
+
compiler_params=pltpu.CompilerParams(
|
|
1547
|
+
collective_id=0,
|
|
1548
|
+
vmem_limit_bytes=100 * 1024 * 1024,
|
|
1549
|
+
),
|
|
1550
|
+
name=scope_name,
|
|
1551
|
+
)
|
|
1007
1552
|
|
|
1008
1553
|
@jax.jit
|
|
1009
|
-
@
|
|
1010
|
-
shard_map.shard_map,
|
|
1554
|
+
@jax.shard_map(
|
|
1011
1555
|
mesh=mesh,
|
|
1012
|
-
in_specs=(
|
|
1013
|
-
|
|
1556
|
+
in_specs=(
|
|
1557
|
+
P(ep_axis_name), # tokens_hbm
|
|
1558
|
+
P(ep_axis_name), # w1_hbm
|
|
1559
|
+
P(ep_axis_name), # w2_hbm
|
|
1560
|
+
None if w1_scale is None else P(ep_axis_name), # w1_scale_hbm
|
|
1561
|
+
None if w2_scale is None else P(ep_axis_name), # w2_scale_hbm
|
|
1562
|
+
None if b1 is None else P(ep_axis_name), # b1_hbm
|
|
1563
|
+
None if b2 is None else P(ep_axis_name), # b2_hbm
|
|
1564
|
+
P(ep_axis_name), # gating_output_hbm
|
|
1565
|
+
P(), # a2a_g_hbm
|
|
1566
|
+
),
|
|
1014
1567
|
out_specs=P(ep_axis_name),
|
|
1015
|
-
|
|
1568
|
+
check_vma=False,
|
|
1016
1569
|
)
|
|
1017
|
-
def kernel(
|
|
1570
|
+
def kernel(
|
|
1571
|
+
tokens,
|
|
1572
|
+
w1,
|
|
1573
|
+
w2,
|
|
1574
|
+
w1_scale,
|
|
1575
|
+
w2_scale,
|
|
1576
|
+
b1,
|
|
1577
|
+
b2,
|
|
1578
|
+
gating_output,
|
|
1579
|
+
a2a_g_hbm_scratch,
|
|
1580
|
+
):
|
|
1018
1581
|
return fused_moe(
|
|
1019
|
-
pltpu.with_memory_space_constraint(tokens,
|
|
1020
|
-
|
|
1021
|
-
pltpu.with_memory_space_constraint(
|
|
1022
|
-
pltpu.with_memory_space_constraint(
|
|
1023
|
-
pltpu.with_memory_space_constraint(
|
|
1582
|
+
pltpu.with_memory_space_constraint(tokens,
|
|
1583
|
+
pltpu.HBM), # tokens_hbm
|
|
1584
|
+
pltpu.with_memory_space_constraint(w1, pltpu.HBM), # w1_hbm
|
|
1585
|
+
pltpu.with_memory_space_constraint(w2, pltpu.HBM), # w2_hbm
|
|
1586
|
+
(None if w1_scale is None else pltpu.with_memory_space_constraint(
|
|
1587
|
+
w1_scale, pltpu.HBM)), # w1_scale_hbm
|
|
1588
|
+
(None if w2_scale is None else pltpu.with_memory_space_constraint(
|
|
1589
|
+
w2_scale, pltpu.HBM)), # w2_scale_hbm
|
|
1590
|
+
(None if b1 is None else pltpu.with_memory_space_constraint(
|
|
1591
|
+
b1, pltpu.HBM)), # b1_hbm
|
|
1592
|
+
(None if b2 is None else pltpu.with_memory_space_constraint(
|
|
1593
|
+
b2, pltpu.HBM)), # b2_hbm
|
|
1594
|
+
pltpu.with_memory_space_constraint(gating_output,
|
|
1595
|
+
pltpu.HBM), # gating_output_hbm
|
|
1596
|
+
pltpu.with_memory_space_constraint(a2a_g_hbm_scratch,
|
|
1597
|
+
pltpu.HBM), # a2a_g_hbm
|
|
1024
1598
|
)
|
|
1025
1599
|
|
|
1026
1600
|
a2a_g_hbm_scratch = pl.empty(
|
|
1027
1601
|
(num_experts, bt, t_packing, hidden_size // t_packing), t_dtype)
|
|
1028
|
-
|
|
1602
|
+
return kernel(
|
|
1029
1603
|
tokens,
|
|
1030
1604
|
w1,
|
|
1031
1605
|
w2,
|
|
1606
|
+
w1_scale,
|
|
1607
|
+
w2_scale,
|
|
1608
|
+
b1,
|
|
1609
|
+
b2,
|
|
1032
1610
|
gating_output,
|
|
1033
1611
|
a2a_g_hbm_scratch,
|
|
1034
1612
|
)
|
|
1035
|
-
return results[:, :actual_hidden_size]
|