tpu-inference 0.12.0.dev20251222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_core_tpu.py +513 -0
- tests/core/test_disagg_executor.py +60 -0
- tests/core/test_disagg_utils.py +67 -0
- tests/core/test_dp_scheduler.py +724 -0
- tests/core/test_init.py +63 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +393 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +291 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +388 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +498 -0
- tests/kernels/quantized_matmul_kernel_test.py +159 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/layers/jax/test_qwix.py +969 -0
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +297 -0
- tests/layers/vllm/test_unquantized.py +621 -0
- tests/layers/vllm/utils.py +72 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +46 -0
- tests/lora/test_bgmv.py +57 -0
- tests/lora/test_layers.py +666 -0
- tests/lora/test_lora.py +147 -0
- tests/lora/test_lora_perf.py +67 -0
- tests/lora/utils.py +88 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +606 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +202 -0
- tests/runner/test_tpu_runner_dp.py +1033 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +215 -0
- tests/test_envs.py +280 -0
- tests/test_tpu_info.py +134 -0
- tests/test_utils.py +193 -0
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +67 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/core_tpu.py +786 -0
- tpu_inference/core/disagg_executor.py +118 -0
- tpu_inference/core/disagg_utils.py +49 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +814 -0
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +81 -0
- tpu_inference/distributed/tpu_connector.py +732 -0
- tpu_inference/distributed/utils.py +112 -0
- tpu_inference/env_override.py +9 -0
- tpu_inference/envs.py +191 -0
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +399 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +272 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/kernel.py +772 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +1340 -0
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
- tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
- tpu_inference/kernels/quantized_matmul/util.py +58 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
- tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +403 -0
- tpu_inference/layers/common/attention_metadata.py +48 -0
- tpu_inference/layers/common/binary_search.py +295 -0
- tpu_inference/layers/common/quant_methods.py +23 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +600 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +268 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
- tpu_inference/layers/jax/base.py +165 -0
- tpu_inference/layers/jax/constants.py +101 -0
- tpu_inference/layers/jax/layers.py +315 -0
- tpu_inference/layers/jax/misc.py +30 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
- tpu_inference/layers/jax/moe/moe.py +249 -0
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +294 -0
- tpu_inference/layers/jax/rope_interface.py +228 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
- tpu_inference/layers/jax/sample/sampling.py +110 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
- tpu_inference/layers/jax/transformer_block.py +121 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +221 -0
- tpu_inference/layers/vllm/fused_moe.py +502 -0
- tpu_inference/layers/vllm/linear_common.py +221 -0
- tpu_inference/layers/vllm/quantization/__init__.py +55 -0
- tpu_inference/layers/vllm/quantization/awq.py +221 -0
- tpu_inference/layers/vllm/quantization/common.py +124 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
- tpu_inference/layers/vllm/sharding.py +244 -0
- tpu_inference/logger.py +10 -0
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +98 -0
- tpu_inference/lora/torch_punica_tpu.py +310 -0
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +520 -0
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +978 -0
- tpu_inference/models/jax/gpt_oss.py +508 -0
- tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
- tpu_inference/models/jax/llama3.py +436 -0
- tpu_inference/models/jax/llama4.py +643 -0
- tpu_inference/models/jax/llama_eagle3.py +350 -0
- tpu_inference/models/jax/llama_guard_4.py +375 -0
- tpu_inference/models/jax/qwen2.py +390 -0
- tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
- tpu_inference/models/jax/qwen3.py +318 -0
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +110 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
- tpu_inference/models/jax/utils/weight_utils.py +621 -0
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
- tpu_inference/platforms/__init__.py +16 -0
- tpu_inference/platforms/tpu_platform.py +258 -0
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/block_table.py +122 -0
- tpu_inference/runner/compilation_manager.py +890 -0
- tpu_inference/runner/input_batch.py +435 -0
- tpu_inference/runner/kv_cache.py +166 -0
- tpu_inference/runner/kv_cache_manager.py +508 -0
- tpu_inference/runner/lora_utils.py +106 -0
- tpu_inference/runner/multimodal_manager.py +231 -0
- tpu_inference/runner/persistent_batch_manager.py +296 -0
- tpu_inference/runner/speculative_decoding_manager.py +262 -0
- tpu_inference/runner/structured_decoding_manager.py +101 -0
- tpu_inference/runner/tpu_runner.py +1768 -0
- tpu_inference/runner/utils.py +426 -0
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +430 -0
- tpu_inference/tpu_info.py +92 -0
- tpu_inference/utils.py +345 -0
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +468 -0
- tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
- tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
- tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
- tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
- tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
|
@@ -0,0 +1,1612 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""TPU-Friendly Fused Mixture of Experts (MoE) kernel."""
|
|
15
|
+
|
|
16
|
+
import functools
|
|
17
|
+
|
|
18
|
+
import jax
|
|
19
|
+
import jax.numpy as jnp
|
|
20
|
+
from jax import lax
|
|
21
|
+
from jax._src import dtypes
|
|
22
|
+
from jax.experimental import pallas as pl
|
|
23
|
+
from jax.experimental.pallas import tpu as pltpu
|
|
24
|
+
|
|
25
|
+
# Shorthand alias for jax.sharding.PartitionSpec.
P = jax.sharding.PartitionSpec
|
|
26
|
+
|
|
27
|
+
# Ceiling-division helper from Pallas; used by align_to and for block counts
# (e.g. num_bt / num_bf) inside the kernel.
cdiv = pl.cdiv
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def align_to(x, a):
    """Round ``x`` up to the nearest multiple of ``a``.

    Implemented with the module-level ``cdiv`` (ceiling division), so it
    accepts whatever ``cdiv`` accepts. Example: ``align_to(130, 128) == 256``.
    """
    num_chunks = cdiv(x, a)
    return num_chunks * a
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def get_dtype_packing(dtype):
|
|
35
|
+
bits = (dtypes.bit_width(dtype)
|
|
36
|
+
if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
|
|
37
|
+
return 32 // bits
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def broadcast_minor(src, shape):
    """Tile ``src`` along its last (minor) axis so it has shape ``shape``.

    Returns ``src`` unchanged when the shapes already match. Otherwise all
    leading dimensions must match and ``src``'s last dimension must be a
    multiple of 128. ``src`` is repeated enough times to cover ``shape[-1]``
    (rounded up to a multiple of ``src.shape[-1]``), and the result is then
    truncated to exactly ``shape[-1]``.
    """
    if src.shape == shape:
        return src
    assert src.shape[:-1] == shape[:-1]
    src_minor = src.shape[-1]
    assert src_minor % 128 == 0
    target_minor = shape[-1]
    num_copies = align_to(target_minor, src_minor) // src_minor
    # The original author labels this concatenation a "no-op"; presumably it
    # lowers to a cheap lane-wise copy on TPU (not verified here).
    tiled = jnp.concatenate([src] * num_copies, axis=-1)
    return tiled[..., :target_minor]
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def swigluoai(gate: jax.Array,
              up: jax.Array,
              *,
              alpha: float = 1.702,
              limit: float = 7.0) -> jax.Array:
    """Activation used in some models such as GPT-OSS.

    Computes ``(clip(up, -limit, limit) + 1) * g * sigmoid(alpha * g)`` where
    ``g = min(gate, limit)``. Note that ``gate`` is only clamped from above.

    Args:
      gate: Gate projection.
      up: Up projection; must broadcast with ``gate``.
      alpha: Slope applied inside the sigmoid.
      limit: Clamp bound for ``gate`` (upper only) and ``up`` (both sides).

    Returns:
      Array with the broadcast shape of ``gate`` and ``up``.
    """
    # Avoid the deprecated ``a_min``/``a_max`` keywords of jnp.clip. Use
    # jnp.minimum for the one-sided clamp, and positional bounds, which are
    # valid under both the old and the new jnp.clip signatures.
    gate = jnp.minimum(gate, limit)
    up = jnp.clip(up, -limit, limit)
    glu = gate * jax.nn.sigmoid(alpha * gate)
    return (up + 1.0) * glu
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def activation_fn(acc1, acc3, act_fn):
|
|
64
|
+
if act_fn == "silu":
|
|
65
|
+
return jax.nn.silu(acc1) * acc3
|
|
66
|
+
elif act_fn == "gelu":
|
|
67
|
+
return jax.nn.gelu(acc1) * acc3
|
|
68
|
+
elif act_fn == "swigluoai":
|
|
69
|
+
return swigluoai(acc1, acc3)
|
|
70
|
+
else:
|
|
71
|
+
raise RuntimeError(f"Unsupported activation function: {act_fn}")
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def ref_moe(
|
|
75
|
+
tokens: jax.Array, # (num_tokens, hidden_size)
|
|
76
|
+
w1: jax.Array, # (num_experts, 2, hidden_size, intermediate_size)
|
|
77
|
+
w2: jax.Array, # (num_experts, intermediate_size, hidden_size)
|
|
78
|
+
gating_output: jax.Array, # (num_tokens, num_experts)
|
|
79
|
+
top_k: int,
|
|
80
|
+
*,
|
|
81
|
+
renormalize_topk_logits: bool = False,
|
|
82
|
+
act_fn: str = "silu",
|
|
83
|
+
subc_quant_wsz: int | None = None,
|
|
84
|
+
w1_scale:
|
|
85
|
+
(
|
|
86
|
+
jax.Array | None
|
|
87
|
+
) = None, # F32(num_experts, 2, hidden_size //subc_quant_wsz, 1, intermediate_size)
|
|
88
|
+
w2_scale:
|
|
89
|
+
(
|
|
90
|
+
jax.Array | None
|
|
91
|
+
) = None, # F32(num_experts, intermediate_size // subc_quant_wsz, 1, hidden_size)
|
|
92
|
+
b1: jax.Array
|
|
93
|
+
| None = None, # F32(num_experts, 2, 1, intermediate_size)
|
|
94
|
+
b2: jax.Array | None = None, # F32(num_experts, 1, hidden_size)
|
|
95
|
+
):
|
|
96
|
+
n_tokens = tokens.shape[0] # num_tokens
|
|
97
|
+
|
|
98
|
+
# Compute gating scores for all experts
|
|
99
|
+
gating_logits = jax.nn.softmax(gating_output,
|
|
100
|
+
axis=-1) # [num_tokens, n_experts]
|
|
101
|
+
|
|
102
|
+
# Select top-k experts per token
|
|
103
|
+
top_k_logits, top_k_indices = lax.top_k(
|
|
104
|
+
gating_logits, top_k) # [num_tokens, top_k], [num_tokens, top_k]
|
|
105
|
+
|
|
106
|
+
if renormalize_topk_logits:
|
|
107
|
+
top_k_logits = top_k_logits / jnp.sum(
|
|
108
|
+
top_k_logits, axis=-1, keepdims=True)
|
|
109
|
+
|
|
110
|
+
t_outputs = []
|
|
111
|
+
hidden_size, intermediate_size = w1.shape[-2:]
|
|
112
|
+
|
|
113
|
+
# Process each token individually
|
|
114
|
+
for i in range(n_tokens):
|
|
115
|
+
curr_token = jnp.expand_dims(tokens[i], axis=0) # [1, hidden_size]
|
|
116
|
+
assigned_expert_ids = top_k_indices[
|
|
117
|
+
i] # [top_k] - indices of selected experts for token i
|
|
118
|
+
tok_expert_act = []
|
|
119
|
+
|
|
120
|
+
# Process each selected expert for the current token
|
|
121
|
+
for expert_id in assigned_expert_ids:
|
|
122
|
+
# Get expert weights
|
|
123
|
+
expert_w1 = w1[expert_id, 0].astype(jnp.float32)
|
|
124
|
+
expert_w3 = w1[expert_id, 1].astype(jnp.float32)
|
|
125
|
+
if w1_scale is not None:
|
|
126
|
+
expert_w1 *= jnp.repeat(w1_scale[expert_id, 0, :, 0],
|
|
127
|
+
subc_quant_wsz,
|
|
128
|
+
axis=0)[:hidden_size]
|
|
129
|
+
expert_w3 *= jnp.repeat(w1_scale[expert_id, 1, :, 0],
|
|
130
|
+
subc_quant_wsz,
|
|
131
|
+
axis=0)[:hidden_size]
|
|
132
|
+
expert_weight_1 = jnp.concat(
|
|
133
|
+
[expert_w1, expert_w3],
|
|
134
|
+
axis=-1) # [hidden_size, 2 * intermediate_size]
|
|
135
|
+
expert_weight_2 = w2[expert_id].astype(
|
|
136
|
+
jnp.float32) # [intermediate_size, hidden_size]
|
|
137
|
+
if w2_scale is not None:
|
|
138
|
+
expert_weight_2 *= jnp.repeat(w2_scale[expert_id, :, 0],
|
|
139
|
+
subc_quant_wsz,
|
|
140
|
+
axis=0)[:intermediate_size]
|
|
141
|
+
|
|
142
|
+
# First linear layer with SwiGLU activation
|
|
143
|
+
gmm_1_out = curr_token @ expert_weight_1 # [1, 2 * intermediate_size]
|
|
144
|
+
|
|
145
|
+
# Split into gate and up projections for SwiGLU
|
|
146
|
+
gmm1_w1_proj, gmm1_w3_proj = jnp.split(
|
|
147
|
+
gmm_1_out, 2,
|
|
148
|
+
axis=-1) # [1, intermediate_size], [1, intermediate_size]
|
|
149
|
+
if b1 is not None:
|
|
150
|
+
gmm1_w1_proj += b1[expert_id:expert_id + 1, 0, 0]
|
|
151
|
+
gmm1_w3_proj += b1[expert_id:expert_id + 1, 1, 0]
|
|
152
|
+
|
|
153
|
+
# Apply gated activation: activation(gate) * up
|
|
154
|
+
act = activation_fn(gmm1_w1_proj, gmm1_w3_proj, act_fn)
|
|
155
|
+
|
|
156
|
+
# Second linear layer (down projection)
|
|
157
|
+
gmm_2_out = act @ expert_weight_2 # [1, hidden_size]
|
|
158
|
+
if b2 is not None:
|
|
159
|
+
gmm_2_out += b2[expert_id:expert_id + 1, 0]
|
|
160
|
+
tok_expert_act.append(gmm_2_out)
|
|
161
|
+
|
|
162
|
+
# Combine outputs from all selected experts
|
|
163
|
+
experts_act = jnp.concatenate(tok_expert_act,
|
|
164
|
+
axis=0) # [top_k, hidden_size]
|
|
165
|
+
|
|
166
|
+
# Weighted sum using top-k gating weights
|
|
167
|
+
top_k_weights = top_k_logits[i] # [top_k]
|
|
168
|
+
top_k_weights = jnp.expand_dims(top_k_weights, axis=1) # [top_k, 1]
|
|
169
|
+
weighted_output = jnp.sum(experts_act * top_k_weights,
|
|
170
|
+
axis=0,
|
|
171
|
+
keepdims=True) # [1, hidden_size]
|
|
172
|
+
|
|
173
|
+
t_outputs.append(weighted_output.astype(tokens.dtype))
|
|
174
|
+
|
|
175
|
+
return jnp.concatenate(t_outputs,
|
|
176
|
+
axis=0) # [actual_num_tokens, hidden_size]
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def _fused_ep_moe_kernel(
|
|
180
|
+
# Input
|
|
181
|
+
tokens_hbm, # (local_num_tokens, t_packing, hidden_size // t_packing)
|
|
182
|
+
w1_hbm, # (local_num_experts, 2, hidden_size, intermediate_size)
|
|
183
|
+
w2_hbm, # (local_num_experts, intermediate_size, hidden_size)
|
|
184
|
+
# TODO(jevinjiang): We choose F32 scale for easier slicing. The extra
|
|
185
|
+
# latency should be hidden in the pipeline overlapping. But is there a better
|
|
186
|
+
# way to do this?
|
|
187
|
+
w1_scale_hbm, # None | F32(local_num_experts, 2, cdiv(hidden_size, subc_quant_wsz), 1, intermediate_size)
|
|
188
|
+
w2_scale_hbm, # None | F32(local_num_experts, cdiv(intermediate_size, subc_quant_wsz), 1, hidden_size)
|
|
189
|
+
b1_hbm, # None | F32(local_num_experts, 2, 1, intermediate_size)
|
|
190
|
+
b2_hbm, # None | F32(local_num_experts, 1, hidden_size)
|
|
191
|
+
gating_hbm, # (local_num_tokens, padded_num_experts)
|
|
192
|
+
a2a_g_hbm, # (num_experts, bt, t_packing, hidden_size // t_packing)
|
|
193
|
+
# Output
|
|
194
|
+
output_hbm, # (local_num_tokens, hidden_size)
|
|
195
|
+
# Scratch
|
|
196
|
+
t2e_routing_x2_smem, # <bt_sem_id> (2, bt, padded_top_k)
|
|
197
|
+
d2e_count_x2_smem, # <bt_sem_id> (2, num_devices, 1, padded_num_experts)
|
|
198
|
+
expert_offsets_x2_smem, # <bt_sem_id> (2, 2, padded_num_experts): for a2a_s and a2a_g
|
|
199
|
+
expert_starts_x2_smem, # <bt_sem_id> (2, 1, padded_num_experts)
|
|
200
|
+
expert_sizes_x2_smem, # <bt_sem_id> (2, 1, padded_num_experts)
|
|
201
|
+
a2a_s_sends_x2_smem, # <e_sem_id> (2,)
|
|
202
|
+
a2a_s_x2_vmem, # <e_sem_id> (2, bt * num_devices, t_packing, hidden_size // t_packing)
|
|
203
|
+
a2a_s_acc_x2_vmem, # <e_sem_id> (2, bt * num_devices, t_packing, hidden_size // t_packing)
|
|
204
|
+
### Accumulation for gathered tokens:
|
|
205
|
+
a2a_g_acc_vmem, # (top_k, bt, t_packing, hidden_size // t_packing)
|
|
206
|
+
### Expert weight double buffering:
|
|
207
|
+
b_gating_x2_vmem, # <bt_sem_id> (2, bt, padded_num_experts)
|
|
208
|
+
b_output_x2_vmem, # <bt_sem_id> (2, bt, hidden_size)
|
|
209
|
+
b_w1_x2_vmem, # <bw_sem_id> (2, t_packing, bd1 // t_packing, bf)
|
|
210
|
+
b_w3_x2_vmem, # <bw_sem_id> (2, t_packing, bd1 // t_packing, bf)
|
|
211
|
+
b_w2_x2_vmem, # <bw_sem_id> (2, t_packing, bf, bd2 // t_packing)
|
|
212
|
+
b_w1_scale_x2_vmem, # None | <bw_sem_id> (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf)
|
|
213
|
+
b_w3_scale_x2_vmem, # None | <bw_sem_id> (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf)
|
|
214
|
+
b_w2_scale_x2_vmem, # None | <bw_sem_id> (2, t_packing, bf // subc_quant_wsz, 1, bd2 // t_packing)
|
|
215
|
+
b_b1_x2_vmem, # None | <bw_sem_id> (2, 1, bf)
|
|
216
|
+
b_b3_x2_vmem, # None | <bw_sem_id> (2, 1, bf)
|
|
217
|
+
b_b2_x2_vmem, # None | <bw_sem_id> (2, t_packing, 1, bd2 // t_packing)
|
|
218
|
+
b_acc_vmem, # F32(bt * num_devices, 1, bf * 2)
|
|
219
|
+
### Semaphores:
|
|
220
|
+
local_sems, # (2, 5): 2 x [b_gating_sem, b_w1_sem, b_w2_sem, b_w3_sem, b_output_sem]
|
|
221
|
+
send_sems, # <e_sem_id> (2,)
|
|
222
|
+
recv_sems, # <e_sem_id> (2,)
|
|
223
|
+
a2a_gather_sem,
|
|
224
|
+
a2a_acc_sem,
|
|
225
|
+
*,
|
|
226
|
+
top_k: int,
|
|
227
|
+
renormalize_topk_logits: bool,
|
|
228
|
+
ep_axis_name: str,
|
|
229
|
+
act_fn: str,
|
|
230
|
+
subc_quant_wsz: int | None = None,
|
|
231
|
+
# Kernel tuning params.
|
|
232
|
+
bt: int, # Block size of local_num_tokens.
|
|
233
|
+
bf: int, # Block size of intermediate_size.
|
|
234
|
+
bd1: int, # Block size of hidden_size in w1.
|
|
235
|
+
bd2: int, # Block size of hidden_size in w2.
|
|
236
|
+
btc: int, # Compute size of block tokens for active expert.
|
|
237
|
+
bfc: int, # Compute size of block intermediate_size.
|
|
238
|
+
bd1c: int, # Compute size of block hidden_size.
|
|
239
|
+
bd2c: int, # Compute size of block hidden_size.
|
|
240
|
+
):
|
|
241
|
+
my_id = lax.axis_index(ep_axis_name)
|
|
242
|
+
num_devices = lax.axis_size(ep_axis_name)
|
|
243
|
+
local_num_tokens = tokens_hbm.shape[0]
|
|
244
|
+
local_num_experts, intermediate_size, hidden_size = w2_hbm.shape
|
|
245
|
+
right_id = (my_id + 1) % num_devices
|
|
246
|
+
num_experts = a2a_g_hbm.shape[0]
|
|
247
|
+
padded_num_experts = d2e_count_x2_smem.shape[-1]
|
|
248
|
+
padded_top_k = t2e_routing_x2_smem.shape[-1]
|
|
249
|
+
assert padded_num_experts == align_to(num_experts, 128)
|
|
250
|
+
assert padded_top_k == align_to(top_k, 128)
|
|
251
|
+
|
|
252
|
+
t_dtype = tokens_hbm.dtype
|
|
253
|
+
t_packing = get_dtype_packing(t_dtype)
|
|
254
|
+
t_bitwidth = 32 // t_packing
|
|
255
|
+
assert a2a_g_hbm.dtype == t_dtype
|
|
256
|
+
assert w1_hbm.dtype == w2_hbm.dtype
|
|
257
|
+
|
|
258
|
+
assert bd1 % bd1c == 0
|
|
259
|
+
assert bd2 % bd2c == 0
|
|
260
|
+
assert bf % bfc == 0
|
|
261
|
+
assert hidden_size % t_packing == 0
|
|
262
|
+
assert bd1 % t_packing == 0
|
|
263
|
+
assert bd2 % t_packing == 0
|
|
264
|
+
assert bd1c % t_packing == 0
|
|
265
|
+
assert bd2c % t_packing == 0
|
|
266
|
+
|
|
267
|
+
h_per_t_packing = hidden_size // t_packing
|
|
268
|
+
assert tokens_hbm.shape[-1] == h_per_t_packing
|
|
269
|
+
bd1_per_t_packing = bd1 // t_packing
|
|
270
|
+
bd2_per_t_packing = bd2 // t_packing
|
|
271
|
+
bd1c_per_t_packing = bd1c // t_packing
|
|
272
|
+
bd2c_per_t_packing = bd2c // t_packing
|
|
273
|
+
|
|
274
|
+
if subc_quant_wsz is not None:
|
|
275
|
+
assert subc_quant_wsz % 256 == 0
|
|
276
|
+
assert bd1c_per_t_packing == subc_quant_wsz
|
|
277
|
+
assert bfc == subc_quant_wsz
|
|
278
|
+
assert bd1 % subc_quant_wsz == 0
|
|
279
|
+
assert bf % subc_quant_wsz == 0
|
|
280
|
+
assert bd1_per_t_packing % subc_quant_wsz == 0
|
|
281
|
+
assert h_per_t_packing % subc_quant_wsz == 0
|
|
282
|
+
|
|
283
|
+
num_bt = cdiv(local_num_tokens, bt)
|
|
284
|
+
num_bf = cdiv(intermediate_size, bf)
|
|
285
|
+
num_bd1 = cdiv(hidden_size, bd1)
|
|
286
|
+
num_bd2 = cdiv(hidden_size, bd2)
|
|
287
|
+
|
|
288
|
+
def get_mesh_device_id(ep_rank):
|
|
289
|
+
dp_rank = jax.lax.axis_index("data")
|
|
290
|
+
return (dp_rank, ep_rank)
|
|
291
|
+
|
|
292
|
+
def sync_barrier():
|
|
293
|
+
barrier_sem = pltpu.get_barrier_semaphore()
|
|
294
|
+
pltpu.semaphore_signal(
|
|
295
|
+
barrier_sem,
|
|
296
|
+
device_id=get_mesh_device_id(right_id),
|
|
297
|
+
device_id_type=pltpu.DeviceIdType.MESH,
|
|
298
|
+
)
|
|
299
|
+
pltpu.semaphore_wait(barrier_sem, 1)
|
|
300
|
+
|
|
301
|
+
def start_fetch_b_gating(bt_id, priority=0):
|
|
302
|
+
is_valid = jnp.logical_and(0 <= bt_id, bt_id < num_bt)
|
|
303
|
+
sz = pl.multiple_of(lax.select(is_valid, bt, 0), bt)
|
|
304
|
+
bt_sem_id = (bt_id + 2) % 2
|
|
305
|
+
b_gating_sem = local_sems.at[bt_sem_id, 0]
|
|
306
|
+
pltpu.make_async_copy(
|
|
307
|
+
src_ref=gating_hbm.at[pl.ds(bt_id * bt, sz)],
|
|
308
|
+
dst_ref=b_gating_x2_vmem.at[bt_sem_id, pl.ds(0, sz)],
|
|
309
|
+
sem=b_gating_sem,
|
|
310
|
+
).start(priority=priority)
|
|
311
|
+
|
|
312
|
+
def wait_fetch_b_gating(bt_id):
|
|
313
|
+
bt_sem_id = bt_id % 2
|
|
314
|
+
b_gating_sem = local_sems.at[bt_sem_id, 0]
|
|
315
|
+
pltpu.make_async_copy(
|
|
316
|
+
src_ref=b_gating_x2_vmem.at[bt_sem_id],
|
|
317
|
+
dst_ref=b_gating_x2_vmem.at[bt_sem_id],
|
|
318
|
+
sem=b_gating_sem,
|
|
319
|
+
).wait()
|
|
320
|
+
|
|
321
|
+
def get_top_k(input, top_k, renormalize_topk_logits):
|
|
322
|
+
assert len(input.shape) == 2, input.shape
|
|
323
|
+
input = input.astype(jnp.float32)
|
|
324
|
+
padded_k_shape = (input.shape[0], padded_top_k)
|
|
325
|
+
top_k_logits_lst = []
|
|
326
|
+
top_k_indices_lst = []
|
|
327
|
+
t2e = jnp.zeros(input.shape, dtype=jnp.int32)
|
|
328
|
+
t2e_routing = jnp.zeros(padded_k_shape, dtype=jnp.int32)
|
|
329
|
+
iota = jax.lax.broadcasted_iota(jnp.int32, input.shape, 1)
|
|
330
|
+
padded_k_iota = jax.lax.broadcasted_iota(jnp.int32, padded_k_shape, 1)
|
|
331
|
+
top_k_logits_sum = jnp.zeros(padded_k_shape, jnp.float32)
|
|
332
|
+
|
|
333
|
+
for k_id in range(top_k):
|
|
334
|
+
# TODO(jevinjiang): return both top_k values and indices in Mosaic
|
|
335
|
+
top_k_logits = jnp.broadcast_to(
|
|
336
|
+
jnp.max(input[:, :num_experts], axis=1, keepdims=True),
|
|
337
|
+
padded_k_shape,
|
|
338
|
+
).astype(input.dtype)
|
|
339
|
+
top_k_logits_lst.append(top_k_logits)
|
|
340
|
+
if renormalize_topk_logits:
|
|
341
|
+
top_k_logits_sum += top_k_logits
|
|
342
|
+
# TODO(jevinjiang): support bf16 argmax in Mosaic
|
|
343
|
+
top_k_indices = jnp.broadcast_to(
|
|
344
|
+
jnp.argmax(input[:, :num_experts], axis=1, keepdims=True),
|
|
345
|
+
padded_k_shape,
|
|
346
|
+
)
|
|
347
|
+
top_k_indices_lst.append(top_k_indices)
|
|
348
|
+
t2e_routing = jnp.where(padded_k_iota == k_id, top_k_indices,
|
|
349
|
+
t2e_routing)
|
|
350
|
+
mask = iota == broadcast_minor(top_k_indices, input.shape)
|
|
351
|
+
t2e += mask.astype(jnp.int32)
|
|
352
|
+
if k_id != top_k - 1:
|
|
353
|
+
input = jnp.where(mask, -jnp.inf, input)
|
|
354
|
+
|
|
355
|
+
if renormalize_topk_logits:
|
|
356
|
+
for k_id in range(top_k):
|
|
357
|
+
top_k_logits_lst[k_id] /= top_k_logits_sum
|
|
358
|
+
|
|
359
|
+
expert_sizes = jnp.sum(t2e, axis=0, keepdims=True)
|
|
360
|
+
expert_starts = jnp.zeros_like(expert_sizes)
|
|
361
|
+
return top_k_logits_lst, t2e_routing, expert_sizes, expert_starts
|
|
362
|
+
|
|
363
|
+
def all_reduce_metadata(bt_sem_id, t2e_routing, starts, sizes):
|
|
364
|
+
send_sem = send_sems.at[0]
|
|
365
|
+
recv_sem = recv_sems.at[0]
|
|
366
|
+
|
|
367
|
+
# All-reduce to accumulate starts and sizes and transfer to SMEM.
|
|
368
|
+
def _all_reduce_metadata(
|
|
369
|
+
t2e_routing_vmem,
|
|
370
|
+
d2e_count_vmem,
|
|
371
|
+
offsets_vmem,
|
|
372
|
+
starts_vmem,
|
|
373
|
+
sizes_vmem,
|
|
374
|
+
):
|
|
375
|
+
offsets_vmem[...] = jnp.zeros_like(offsets_vmem)
|
|
376
|
+
# TODO(jevinjiang): check how slow is VMEM -> SMEM.
|
|
377
|
+
offsets_copy = pltpu.async_copy(
|
|
378
|
+
src_ref=offsets_vmem,
|
|
379
|
+
dst_ref=expert_offsets_x2_smem.at[bt_sem_id],
|
|
380
|
+
sem=send_sem,
|
|
381
|
+
)
|
|
382
|
+
t2e_routing_vmem[...] = t2e_routing
|
|
383
|
+
t2e_routing_copy = pltpu.async_copy(
|
|
384
|
+
src_ref=t2e_routing_vmem,
|
|
385
|
+
dst_ref=t2e_routing_x2_smem.at[bt_sem_id],
|
|
386
|
+
sem=send_sem,
|
|
387
|
+
)
|
|
388
|
+
reduced_sizes = sizes
|
|
389
|
+
reduced_starts = starts
|
|
390
|
+
row_id = my_id
|
|
391
|
+
d2e_count_vmem[row_id] = sizes
|
|
392
|
+
for i in range(num_devices - 1):
|
|
393
|
+
sync_barrier()
|
|
394
|
+
# TODO(jevinjiang): we can use double buffering to improve AR if needed.
|
|
395
|
+
pltpu.async_remote_copy(
|
|
396
|
+
src_ref=d2e_count_vmem.at[row_id],
|
|
397
|
+
dst_ref=d2e_count_vmem.at[row_id],
|
|
398
|
+
send_sem=send_sem,
|
|
399
|
+
recv_sem=recv_sem,
|
|
400
|
+
device_id=get_mesh_device_id(right_id),
|
|
401
|
+
device_id_type=pltpu.DeviceIdType.MESH,
|
|
402
|
+
).wait()
|
|
403
|
+
row_id = (row_id + num_devices - 1) % num_devices
|
|
404
|
+
new_sizes = d2e_count_vmem[row_id]
|
|
405
|
+
reduced_sizes += new_sizes
|
|
406
|
+
reduced_starts += lax.select(my_id > i, new_sizes,
|
|
407
|
+
jnp.zeros_like(new_sizes))
|
|
408
|
+
starts_vmem[...] = reduced_starts
|
|
409
|
+
sizes_vmem[...] = reduced_sizes
|
|
410
|
+
|
|
411
|
+
starts_copy = pltpu.async_copy(
|
|
412
|
+
src_ref=starts_vmem,
|
|
413
|
+
dst_ref=expert_starts_x2_smem.at[bt_sem_id],
|
|
414
|
+
sem=send_sem,
|
|
415
|
+
)
|
|
416
|
+
sizes_copy = pltpu.async_copy(
|
|
417
|
+
src_ref=sizes_vmem,
|
|
418
|
+
dst_ref=expert_sizes_x2_smem.at[bt_sem_id],
|
|
419
|
+
sem=send_sem,
|
|
420
|
+
)
|
|
421
|
+
|
|
422
|
+
# TODO(jevinjiang): if d2e_count is too big, we can store in HBM and fetch
|
|
423
|
+
# to SMEM partially.
|
|
424
|
+
d2e_count_copy = pltpu.async_copy(
|
|
425
|
+
src_ref=d2e_count_vmem,
|
|
426
|
+
dst_ref=d2e_count_x2_smem.at[bt_sem_id],
|
|
427
|
+
sem=send_sem,
|
|
428
|
+
)
|
|
429
|
+
|
|
430
|
+
t2e_routing_copy.wait()
|
|
431
|
+
d2e_count_copy.wait()
|
|
432
|
+
offsets_copy.wait()
|
|
433
|
+
starts_copy.wait()
|
|
434
|
+
sizes_copy.wait()
|
|
435
|
+
|
|
436
|
+
pl.run_scoped(
|
|
437
|
+
_all_reduce_metadata,
|
|
438
|
+
pltpu.VMEM(t2e_routing_x2_smem.shape[1:],
|
|
439
|
+
t2e_routing_x2_smem.dtype),
|
|
440
|
+
pltpu.VMEM(d2e_count_x2_smem.shape[1:], d2e_count_x2_smem.dtype),
|
|
441
|
+
pltpu.VMEM(expert_offsets_x2_smem.shape[1:],
|
|
442
|
+
expert_offsets_x2_smem.dtype),
|
|
443
|
+
pltpu.VMEM(expert_starts_x2_smem.shape[1:],
|
|
444
|
+
expert_starts_x2_smem.dtype),
|
|
445
|
+
pltpu.VMEM(expert_sizes_x2_smem.shape[1:],
|
|
446
|
+
expert_sizes_x2_smem.dtype),
|
|
447
|
+
)
|
|
448
|
+
|
|
449
|
+
def start_a2a_scatter(bt_id, e_sem_id, local_e_id):
|
|
450
|
+
bt_sem_id = bt_id % 2
|
|
451
|
+
|
|
452
|
+
# Counting the number of remote sends from the current device.
|
|
453
|
+
send_sz = 0
|
|
454
|
+
for bt_t_id in range(bt):
|
|
455
|
+
for k_id in range(top_k):
|
|
456
|
+
e_id = t2e_routing_x2_smem[bt_sem_id, bt_t_id, k_id]
|
|
457
|
+
is_active_expert = e_id % local_num_experts == local_e_id
|
|
458
|
+
recv_id = e_id // local_num_experts
|
|
459
|
+
offset = expert_offsets_x2_smem[bt_sem_id, 0, e_id]
|
|
460
|
+
sz = lax.select(is_active_expert, 1, 0)
|
|
461
|
+
is_local = recv_id == my_id
|
|
462
|
+
local_sz = lax.select(is_local, sz, 0)
|
|
463
|
+
remote_sz = lax.select(is_local, 0, sz)
|
|
464
|
+
send_sz += remote_sz
|
|
465
|
+
expert_offsets_x2_smem[bt_sem_id, 0,
|
|
466
|
+
e_id] = (offset + local_sz + remote_sz)
|
|
467
|
+
start = expert_starts_x2_smem[bt_sem_id, 0, e_id] + offset
|
|
468
|
+
t_id = bt * bt_id + bt_t_id
|
|
469
|
+
# TODO(jevinjiang): compare the perf when using branches.
|
|
470
|
+
pltpu.make_async_copy(
|
|
471
|
+
src_ref=tokens_hbm.at[pl.ds(t_id, local_sz)],
|
|
472
|
+
dst_ref=a2a_s_x2_vmem.at[e_sem_id,
|
|
473
|
+
pl.ds(start, local_sz)],
|
|
474
|
+
sem=recv_sems.at[e_sem_id],
|
|
475
|
+
).start()
|
|
476
|
+
pltpu.make_async_remote_copy(
|
|
477
|
+
src_ref=tokens_hbm.at[pl.ds(t_id, remote_sz)],
|
|
478
|
+
dst_ref=a2a_s_x2_vmem.at[e_sem_id,
|
|
479
|
+
pl.ds(start, remote_sz)],
|
|
480
|
+
send_sem=send_sems.at[e_sem_id],
|
|
481
|
+
recv_sem=recv_sems.at[e_sem_id],
|
|
482
|
+
device_id=get_mesh_device_id(recv_id),
|
|
483
|
+
device_id_type=pltpu.DeviceIdType.MESH,
|
|
484
|
+
).start()
|
|
485
|
+
a2a_s_sends_x2_smem[e_sem_id] = send_sz
|
|
486
|
+
|
|
487
|
+
def wait_a2a_scatter_recv(bt_id, e_sem_id, local_e_id):
|
|
488
|
+
bt_sem_id = bt_id % 2
|
|
489
|
+
e_id = my_id * local_num_experts + local_e_id
|
|
490
|
+
sz = expert_sizes_x2_smem[bt_sem_id, 0, e_id]
|
|
491
|
+
pltpu.make_async_copy(
|
|
492
|
+
src_ref=a2a_s_x2_vmem.at[e_sem_id, pl.ds(0, sz)],
|
|
493
|
+
dst_ref=a2a_s_x2_vmem.at[e_sem_id, pl.ds(0, sz)],
|
|
494
|
+
sem=recv_sems.at[e_sem_id],
|
|
495
|
+
).wait()
|
|
496
|
+
|
|
497
|
+
def wait_a2a_scatter_send(bt_id, e_sem_id, local_e_id):
|
|
498
|
+
del bt_id, local_e_id
|
|
499
|
+
sz = a2a_s_sends_x2_smem[e_sem_id]
|
|
500
|
+
pltpu.make_async_copy(
|
|
501
|
+
src_ref=a2a_s_x2_vmem.at[e_sem_id, pl.ds(0, sz)],
|
|
502
|
+
dst_ref=a2a_s_x2_vmem.at[e_sem_id, pl.ds(0, sz)],
|
|
503
|
+
sem=send_sems.at[e_sem_id],
|
|
504
|
+
).wait()
|
|
505
|
+
|
|
506
|
+
def start_a2a_gather(bt_id, e_sem_id, local_e_id):
|
|
507
|
+
my_e_id = my_id * local_num_experts + local_e_id
|
|
508
|
+
bt_sem_id = bt_id % 2
|
|
509
|
+
start = 0
|
|
510
|
+
for recv_id in range(num_devices):
|
|
511
|
+
sz = d2e_count_x2_smem[bt_sem_id, recv_id, 0, my_e_id]
|
|
512
|
+
is_local = recv_id == my_id
|
|
513
|
+
local_sz = lax.select(is_local, sz, 0)
|
|
514
|
+
remote_sz = lax.select(is_local, 0, sz)
|
|
515
|
+
pltpu.make_async_copy(
|
|
516
|
+
src_ref=a2a_s_acc_x2_vmem.at[e_sem_id,
|
|
517
|
+
pl.ds(start, local_sz)],
|
|
518
|
+
dst_ref=a2a_g_hbm.at[my_e_id, pl.ds(0, local_sz)],
|
|
519
|
+
sem=a2a_gather_sem,
|
|
520
|
+
).start()
|
|
521
|
+
pltpu.make_async_remote_copy(
|
|
522
|
+
src_ref=a2a_s_acc_x2_vmem.at[e_sem_id,
|
|
523
|
+
pl.ds(start, remote_sz)],
|
|
524
|
+
dst_ref=a2a_g_hbm.at[my_e_id, pl.ds(0, remote_sz)],
|
|
525
|
+
send_sem=send_sems.at[e_sem_id],
|
|
526
|
+
recv_sem=a2a_gather_sem,
|
|
527
|
+
device_id=get_mesh_device_id(recv_id),
|
|
528
|
+
device_id_type=pltpu.DeviceIdType.MESH,
|
|
529
|
+
).start()
|
|
530
|
+
start += sz
|
|
531
|
+
|
|
532
|
+
def wait_a2a_gather_send(bt_id, e_sem_id, local_e_id):
|
|
533
|
+
my_e_id = my_id * local_num_experts + local_e_id
|
|
534
|
+
bt_sem_id = bt_id % 2
|
|
535
|
+
sz = expert_sizes_x2_smem[bt_sem_id, 0, my_e_id]
|
|
536
|
+
local_sz = d2e_count_x2_smem[bt_sem_id, my_id, 0, my_e_id]
|
|
537
|
+
remote_sz = sz - local_sz
|
|
538
|
+
is_valid = jnp.logical_and(0 <= local_e_id, local_e_id
|
|
539
|
+
< local_num_experts)
|
|
540
|
+
remote_sz = lax.select(is_valid, remote_sz, 0)
|
|
541
|
+
pltpu.make_async_copy(
|
|
542
|
+
src_ref=a2a_g_hbm.at[0, pl.ds(0, remote_sz)],
|
|
543
|
+
dst_ref=a2a_g_hbm.at[0, pl.ds(0, remote_sz)],
|
|
544
|
+
sem=send_sems.at[e_sem_id],
|
|
545
|
+
).wait()
|
|
546
|
+
|
|
547
|
+
def wait_a2a_gather_recv_all():
|
|
548
|
+
sz = top_k * bt
|
|
549
|
+
pltpu.make_async_copy(
|
|
550
|
+
src_ref=a2a_g_hbm.at[0, pl.ds(0, sz)],
|
|
551
|
+
dst_ref=a2a_g_hbm.at[0, pl.ds(0, sz)],
|
|
552
|
+
sem=a2a_gather_sem,
|
|
553
|
+
).wait()
|
|
554
|
+
|
|
555
|
+
def start_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id):
|
|
556
|
+
for p in range(t_packing):
|
|
557
|
+
offset = p * h_per_t_packing + bd1_id * bd1_per_t_packing
|
|
558
|
+
pltpu.make_async_copy(
|
|
559
|
+
src_ref=w1_hbm.at[
|
|
560
|
+
local_e_id,
|
|
561
|
+
0,
|
|
562
|
+
pl.ds(offset, bd1_per_t_packing),
|
|
563
|
+
pl.ds(bf_id * bf, bf),
|
|
564
|
+
],
|
|
565
|
+
dst_ref=b_w1_x2_vmem.at[bw1_sem_id, p],
|
|
566
|
+
sem=local_sems.at[bw1_sem_id, 1],
|
|
567
|
+
).start()
|
|
568
|
+
if w1_scale_hbm is not None:
|
|
569
|
+
assert subc_quant_wsz is not None
|
|
570
|
+
pltpu.make_async_copy(
|
|
571
|
+
src_ref=w1_scale_hbm.at[
|
|
572
|
+
local_e_id,
|
|
573
|
+
0,
|
|
574
|
+
pl.ds(
|
|
575
|
+
offset // subc_quant_wsz,
|
|
576
|
+
bd1_per_t_packing // subc_quant_wsz,
|
|
577
|
+
),
|
|
578
|
+
pl.ds(0, 1),
|
|
579
|
+
pl.ds(bf_id * bf, bf),
|
|
580
|
+
],
|
|
581
|
+
dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id, p],
|
|
582
|
+
sem=local_sems.at[bw1_sem_id, 1],
|
|
583
|
+
).start()
|
|
584
|
+
if b1_hbm is not None and bd1_id == 0:
|
|
585
|
+
pltpu.make_async_copy(
|
|
586
|
+
src_ref=b1_hbm.at[local_e_id, 0,
|
|
587
|
+
pl.ds(0, 1),
|
|
588
|
+
pl.ds(bf_id * bf, bf)],
|
|
589
|
+
dst_ref=b_b1_x2_vmem.at[bf_id % 2],
|
|
590
|
+
sem=local_sems.at[bw1_sem_id, 1],
|
|
591
|
+
).start()
|
|
592
|
+
|
|
593
|
+
def start_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id):
|
|
594
|
+
for p in range(t_packing):
|
|
595
|
+
offset = p * h_per_t_packing + bd2_id * bd2_per_t_packing
|
|
596
|
+
pltpu.make_async_copy(
|
|
597
|
+
src_ref=w2_hbm.at[
|
|
598
|
+
local_e_id,
|
|
599
|
+
pl.ds(bf_id * bf, bf),
|
|
600
|
+
pl.ds(offset, bd2_per_t_packing),
|
|
601
|
+
],
|
|
602
|
+
dst_ref=b_w2_x2_vmem.at[bw2_sem_id, p],
|
|
603
|
+
sem=local_sems.at[bw2_sem_id, 2],
|
|
604
|
+
).start()
|
|
605
|
+
if w2_scale_hbm is not None:
|
|
606
|
+
assert subc_quant_wsz is not None
|
|
607
|
+
pltpu.make_async_copy(
|
|
608
|
+
src_ref=w2_scale_hbm.at[
|
|
609
|
+
local_e_id,
|
|
610
|
+
pl.ds(bf_id * bf // subc_quant_wsz, bf //
|
|
611
|
+
subc_quant_wsz),
|
|
612
|
+
pl.ds(0, 1),
|
|
613
|
+
pl.ds(offset, bd2_per_t_packing),
|
|
614
|
+
],
|
|
615
|
+
dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id, p],
|
|
616
|
+
sem=local_sems.at[bw2_sem_id, 2],
|
|
617
|
+
).start()
|
|
618
|
+
if b2_hbm is not None and bf_id == 0:
|
|
619
|
+
pltpu.make_async_copy(
|
|
620
|
+
src_ref=b2_hbm.at[local_e_id,
|
|
621
|
+
pl.ds(0, 1),
|
|
622
|
+
pl.ds(offset, bd2_per_t_packing)],
|
|
623
|
+
dst_ref=b_b2_x2_vmem.at[bd2_id % 2, p],
|
|
624
|
+
sem=local_sems.at[bw2_sem_id, 2],
|
|
625
|
+
).start()
|
|
626
|
+
|
|
627
|
+
def start_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id):
|
|
628
|
+
for p in range(t_packing):
|
|
629
|
+
offset = p * h_per_t_packing + bd3_id * bd1_per_t_packing
|
|
630
|
+
pltpu.make_async_copy(
|
|
631
|
+
src_ref=w1_hbm.at[
|
|
632
|
+
local_e_id,
|
|
633
|
+
1,
|
|
634
|
+
pl.ds(offset, bd1_per_t_packing),
|
|
635
|
+
pl.ds(bf_id * bf, bf),
|
|
636
|
+
],
|
|
637
|
+
dst_ref=b_w3_x2_vmem.at[bw3_sem_id, p],
|
|
638
|
+
sem=local_sems.at[bw3_sem_id, 3],
|
|
639
|
+
).start()
|
|
640
|
+
if w1_scale_hbm is not None:
|
|
641
|
+
assert subc_quant_wsz is not None
|
|
642
|
+
pltpu.make_async_copy(
|
|
643
|
+
src_ref=w1_scale_hbm.at[
|
|
644
|
+
local_e_id,
|
|
645
|
+
1,
|
|
646
|
+
pl.ds(
|
|
647
|
+
offset // subc_quant_wsz,
|
|
648
|
+
bd1_per_t_packing // subc_quant_wsz,
|
|
649
|
+
),
|
|
650
|
+
pl.ds(0, 1),
|
|
651
|
+
pl.ds(bf_id * bf, bf),
|
|
652
|
+
],
|
|
653
|
+
dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id, p],
|
|
654
|
+
sem=local_sems.at[bw3_sem_id, 3],
|
|
655
|
+
).start()
|
|
656
|
+
if b1_hbm is not None and bd3_id == 0:
|
|
657
|
+
pltpu.make_async_copy(
|
|
658
|
+
src_ref=b1_hbm.at[local_e_id, 1,
|
|
659
|
+
pl.ds(0, 1),
|
|
660
|
+
pl.ds(bf_id * bf, bf)],
|
|
661
|
+
dst_ref=b_b3_x2_vmem.at[bf_id % 2],
|
|
662
|
+
sem=local_sems.at[bw3_sem_id, 3],
|
|
663
|
+
).start()
|
|
664
|
+
|
|
665
|
+
def wait_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id):
|
|
666
|
+
del local_e_id
|
|
667
|
+
pltpu.make_async_copy(
|
|
668
|
+
src_ref=b_w1_x2_vmem.at[bw1_sem_id],
|
|
669
|
+
dst_ref=b_w1_x2_vmem.at[bw1_sem_id],
|
|
670
|
+
sem=local_sems.at[bw1_sem_id, 1],
|
|
671
|
+
).wait()
|
|
672
|
+
if w1_scale_hbm is not None:
|
|
673
|
+
pltpu.make_async_copy(
|
|
674
|
+
src_ref=b_w1_scale_x2_vmem.at[bw1_sem_id],
|
|
675
|
+
dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id],
|
|
676
|
+
sem=local_sems.at[bw1_sem_id, 1],
|
|
677
|
+
).wait()
|
|
678
|
+
if b1_hbm is not None and bd1_id == 0:
|
|
679
|
+
pltpu.make_async_copy(
|
|
680
|
+
src_ref=b_b1_x2_vmem.at[bf_id % 2],
|
|
681
|
+
dst_ref=b_b1_x2_vmem.at[bf_id % 2],
|
|
682
|
+
sem=local_sems.at[bw1_sem_id, 1],
|
|
683
|
+
).wait()
|
|
684
|
+
|
|
685
|
+
def wait_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id):
|
|
686
|
+
del local_e_id
|
|
687
|
+
pltpu.make_async_copy(
|
|
688
|
+
src_ref=b_w2_x2_vmem.at[bw2_sem_id],
|
|
689
|
+
dst_ref=b_w2_x2_vmem.at[bw2_sem_id],
|
|
690
|
+
sem=local_sems.at[bw2_sem_id, 2],
|
|
691
|
+
).wait()
|
|
692
|
+
if w2_scale_hbm is not None:
|
|
693
|
+
pltpu.make_async_copy(
|
|
694
|
+
src_ref=b_w2_scale_x2_vmem.at[bw2_sem_id],
|
|
695
|
+
dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id],
|
|
696
|
+
sem=local_sems.at[bw2_sem_id, 2],
|
|
697
|
+
).wait()
|
|
698
|
+
if b2_hbm is not None and bf_id == 0:
|
|
699
|
+
pltpu.make_async_copy(
|
|
700
|
+
src_ref=b_b2_x2_vmem.at[bd2_id % 2],
|
|
701
|
+
dst_ref=b_b2_x2_vmem.at[bd2_id % 2],
|
|
702
|
+
sem=local_sems.at[bw2_sem_id, 2],
|
|
703
|
+
).wait()
|
|
704
|
+
|
|
705
|
+
def wait_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id):
|
|
706
|
+
del local_e_id
|
|
707
|
+
pltpu.make_async_copy(
|
|
708
|
+
src_ref=b_w3_x2_vmem.at[bw3_sem_id],
|
|
709
|
+
dst_ref=b_w3_x2_vmem.at[bw3_sem_id],
|
|
710
|
+
sem=local_sems.at[bw3_sem_id, 3],
|
|
711
|
+
).wait()
|
|
712
|
+
if w1_scale_hbm is not None:
|
|
713
|
+
pltpu.make_async_copy(
|
|
714
|
+
src_ref=b_w3_scale_x2_vmem.at[bw3_sem_id],
|
|
715
|
+
dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id],
|
|
716
|
+
sem=local_sems.at[bw3_sem_id, 3],
|
|
717
|
+
).wait()
|
|
718
|
+
if b1_hbm is not None and bd3_id == 0:
|
|
719
|
+
pltpu.make_async_copy(
|
|
720
|
+
src_ref=b_b3_x2_vmem.at[bf_id % 2],
|
|
721
|
+
dst_ref=b_b3_x2_vmem.at[bf_id % 2],
|
|
722
|
+
sem=local_sems.at[bw3_sem_id, 3],
|
|
723
|
+
).wait()
|
|
724
|
+
|
|
725
|
+
def start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, bd2_id):
|
|
726
|
+
next_bd1_id = bd1_id + 1
|
|
727
|
+
next_bd2_id = bd2_id + 1
|
|
728
|
+
next_sem_id = (bw_sem_id + 1) % 2
|
|
729
|
+
|
|
730
|
+
if bf_id >= num_bf:
|
|
731
|
+
return
|
|
732
|
+
if next_bd1_id < num_bd1:
|
|
733
|
+
start_fetch_bw1(local_e_id, next_sem_id, bf_id, next_bd1_id)
|
|
734
|
+
start_fetch_bw3(local_e_id, next_sem_id, bf_id, next_bd1_id)
|
|
735
|
+
elif next_bd1_id == num_bd1:
|
|
736
|
+
start_fetch_bw2(local_e_id, next_sem_id, bf_id, 0)
|
|
737
|
+
elif next_bd2_id < num_bd2:
|
|
738
|
+
start_fetch_bw2(local_e_id, next_sem_id, bf_id, next_bd2_id)
|
|
739
|
+
elif next_bd2_id == num_bd2:
|
|
740
|
+
start_fetch_next_bw(local_e_id, bw_sem_id, bf_id + 1, -1, -1)
|
|
741
|
+
else:
|
|
742
|
+
raise RuntimeError("Unreachable")
|
|
743
|
+
|
|
744
|
+
def dynamic_ffn1(
|
|
745
|
+
t_b32_vmem,
|
|
746
|
+
w1_vmem,
|
|
747
|
+
w1_scale_vmem,
|
|
748
|
+
b1_vmem,
|
|
749
|
+
w3_vmem,
|
|
750
|
+
w3_scale_vmem,
|
|
751
|
+
b3_vmem,
|
|
752
|
+
acc1_vmem,
|
|
753
|
+
acc3_vmem,
|
|
754
|
+
dyn_sz,
|
|
755
|
+
should_init,
|
|
756
|
+
):
|
|
757
|
+
assert t_b32_vmem.shape == (bt * num_devices, bd1 // t_packing)
|
|
758
|
+
assert w1_vmem.shape == w3_vmem.shape == (t_packing, bd1_per_t_packing,
|
|
759
|
+
bf)
|
|
760
|
+
assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf)
|
|
761
|
+
assert bd1 % (t_packing * 128) == 0, (bd1, t_packing)
|
|
762
|
+
assert bd1c % (t_packing * 128) == 0, (bd1c, t_packing)
|
|
763
|
+
if w1_scale_vmem is not None:
|
|
764
|
+
assert w1_scale_vmem.shape == (
|
|
765
|
+
t_packing,
|
|
766
|
+
bd1_per_t_packing // subc_quant_wsz,
|
|
767
|
+
1,
|
|
768
|
+
bf,
|
|
769
|
+
)
|
|
770
|
+
assert bd1c_per_t_packing == subc_quant_wsz
|
|
771
|
+
if w3_scale_vmem is not None:
|
|
772
|
+
assert w3_scale_vmem.shape == (
|
|
773
|
+
t_packing,
|
|
774
|
+
bd1_per_t_packing // subc_quant_wsz,
|
|
775
|
+
1,
|
|
776
|
+
bf,
|
|
777
|
+
)
|
|
778
|
+
assert bd1c_per_t_packing == subc_quant_wsz
|
|
779
|
+
|
|
780
|
+
num_loops = cdiv(dyn_sz, btc)
|
|
781
|
+
repack_ty = jnp.dtype(f"int{t_bitwidth}")
|
|
782
|
+
|
|
783
|
+
def body(btc_id, _):
|
|
784
|
+
for bd1c_id in range(cdiv(bd1, bd1c)):
|
|
785
|
+
t_b32 = t_b32_vmem[
|
|
786
|
+
pl.ds(btc_id * btc, btc),
|
|
787
|
+
pl.ds(bd1c_id * bd1c_per_t_packing, bd1c_per_t_packing),
|
|
788
|
+
]
|
|
789
|
+
for p_id in range(t_packing):
|
|
790
|
+
t = pltpu.bitcast(t_b32.astype(repack_ty), t_dtype)
|
|
791
|
+
t_b32 = t_b32 >> t_bitwidth
|
|
792
|
+
for bfc_id in range(cdiv(bf, bfc)):
|
|
793
|
+
w_slices = (
|
|
794
|
+
p_id,
|
|
795
|
+
pl.ds(bd1c_id * bd1c_per_t_packing,
|
|
796
|
+
bd1c_per_t_packing),
|
|
797
|
+
pl.ds(bfc_id * bfc, bfc),
|
|
798
|
+
)
|
|
799
|
+
w1 = w1_vmem[*w_slices]
|
|
800
|
+
acc1 = jnp.dot(t,
|
|
801
|
+
w1,
|
|
802
|
+
preferred_element_type=jnp.float32)
|
|
803
|
+
|
|
804
|
+
if w1_scale_vmem is not None:
|
|
805
|
+
w1_scale_slices = (
|
|
806
|
+
p_id,
|
|
807
|
+
bd1c_id,
|
|
808
|
+
pl.ds(0, 1),
|
|
809
|
+
pl.ds(bfc_id * bfc, bfc),
|
|
810
|
+
)
|
|
811
|
+
# TODO(jevinjiang): can use mosaic to load with stride 0.
|
|
812
|
+
w1_scale = jnp.broadcast_to(
|
|
813
|
+
w1_scale_vmem[*w1_scale_slices], acc1.shape)
|
|
814
|
+
acc1 *= w1_scale
|
|
815
|
+
|
|
816
|
+
w3 = w3_vmem[*w_slices]
|
|
817
|
+
|
|
818
|
+
acc3 = jnp.dot(t,
|
|
819
|
+
w3,
|
|
820
|
+
preferred_element_type=jnp.float32)
|
|
821
|
+
|
|
822
|
+
if w3_scale_vmem is not None:
|
|
823
|
+
w3_scale_slices = (
|
|
824
|
+
p_id,
|
|
825
|
+
bd1c_id,
|
|
826
|
+
pl.ds(0, 1),
|
|
827
|
+
pl.ds(bfc_id * bfc, bfc),
|
|
828
|
+
)
|
|
829
|
+
w3_scale = jnp.broadcast_to(
|
|
830
|
+
w3_scale_vmem[*w3_scale_slices], acc3.shape)
|
|
831
|
+
acc3 *= w3_scale
|
|
832
|
+
|
|
833
|
+
acc_slices = (pl.ds(btc_id * btc,
|
|
834
|
+
btc), pl.ds(bfc_id * bfc, bfc))
|
|
835
|
+
if should_init and p_id == bd1c_id == 0:
|
|
836
|
+
if b1_vmem is not None:
|
|
837
|
+
b1_scale_slices = (
|
|
838
|
+
pl.ds(0, 1),
|
|
839
|
+
pl.ds(bfc_id * bfc, bfc),
|
|
840
|
+
)
|
|
841
|
+
b1 = jnp.broadcast_to(
|
|
842
|
+
b1_vmem[*b1_scale_slices], acc1.shape)
|
|
843
|
+
acc1 += b1
|
|
844
|
+
if b3_vmem is not None:
|
|
845
|
+
b3_scale_slices = (
|
|
846
|
+
pl.ds(0, 1),
|
|
847
|
+
pl.ds(bfc_id * bfc, bfc),
|
|
848
|
+
)
|
|
849
|
+
b3 = jnp.broadcast_to(
|
|
850
|
+
b3_vmem[*b3_scale_slices], acc1.shape)
|
|
851
|
+
acc3 += b3
|
|
852
|
+
|
|
853
|
+
acc1_vmem[*acc_slices] = acc1
|
|
854
|
+
acc3_vmem[*acc_slices] = acc3
|
|
855
|
+
else:
|
|
856
|
+
acc1_vmem[*acc_slices] += acc1
|
|
857
|
+
acc3_vmem[*acc_slices] += acc3
|
|
858
|
+
|
|
859
|
+
lax.fori_loop(0, num_loops, body, None)
|
|
860
|
+
|
|
861
|
+
def dynamic_ffn2(
|
|
862
|
+
acc1_vmem,
|
|
863
|
+
acc3_vmem,
|
|
864
|
+
w2_vmem,
|
|
865
|
+
w2_scale_vmem,
|
|
866
|
+
b2_vmem,
|
|
867
|
+
res_b32_vmem,
|
|
868
|
+
dyn_sz,
|
|
869
|
+
should_init,
|
|
870
|
+
):
|
|
871
|
+
assert res_b32_vmem.shape == (bt * num_devices, bd2_per_t_packing)
|
|
872
|
+
assert w2_vmem.shape == (t_packing, bf, bd2_per_t_packing)
|
|
873
|
+
assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf)
|
|
874
|
+
assert bd2 % (t_packing * 128) == 0, (bd2, t_packing)
|
|
875
|
+
assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing)
|
|
876
|
+
assert t_dtype in (jnp.float32, jnp.bfloat16)
|
|
877
|
+
|
|
878
|
+
if w2_scale_vmem is not None:
|
|
879
|
+
assert w2_scale_vmem.shape == (
|
|
880
|
+
t_packing,
|
|
881
|
+
bf // subc_quant_wsz,
|
|
882
|
+
1,
|
|
883
|
+
bd2_per_t_packing,
|
|
884
|
+
)
|
|
885
|
+
assert bfc == subc_quant_wsz
|
|
886
|
+
|
|
887
|
+
num_loops = cdiv(dyn_sz, btc)
|
|
888
|
+
assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing)
|
|
889
|
+
|
|
890
|
+
def body(btc_id, _):
|
|
891
|
+
for bd2c_id in range(cdiv(bd2, bd2c)):
|
|
892
|
+
res_lst = []
|
|
893
|
+
for p_id in range(t_packing):
|
|
894
|
+
res = jnp.zeros((btc, bd2c_per_t_packing),
|
|
895
|
+
dtype=jnp.float32)
|
|
896
|
+
|
|
897
|
+
if b2_vmem is not None and should_init:
|
|
898
|
+
b2_scale_slices = (
|
|
899
|
+
p_id,
|
|
900
|
+
pl.ds(0, 1),
|
|
901
|
+
pl.ds(bd2c_id * bd2c_per_t_packing,
|
|
902
|
+
bd2c_per_t_packing),
|
|
903
|
+
)
|
|
904
|
+
b2 = jnp.broadcast_to(b2_vmem[*b2_scale_slices],
|
|
905
|
+
res.shape)
|
|
906
|
+
res += b2
|
|
907
|
+
|
|
908
|
+
for bfc_id in range(cdiv(bf, bfc)):
|
|
909
|
+
acc_slices = (pl.ds(btc_id * btc,
|
|
910
|
+
btc), pl.ds(bfc_id * bfc, bfc))
|
|
911
|
+
acc1 = acc1_vmem[*acc_slices]
|
|
912
|
+
acc3 = acc3_vmem[*acc_slices]
|
|
913
|
+
act = activation_fn(acc1, acc3, act_fn)
|
|
914
|
+
w2 = w2_vmem[
|
|
915
|
+
p_id,
|
|
916
|
+
pl.ds(bfc_id * bfc, bfc),
|
|
917
|
+
pl.ds(bd2c_id *
|
|
918
|
+
bd2c_per_t_packing, bd2c_per_t_packing),
|
|
919
|
+
]
|
|
920
|
+
acc = jnp.dot(act,
|
|
921
|
+
w2,
|
|
922
|
+
preferred_element_type=jnp.float32)
|
|
923
|
+
if w2_scale_vmem is not None:
|
|
924
|
+
w2_scale_slices = (
|
|
925
|
+
p_id,
|
|
926
|
+
bfc_id,
|
|
927
|
+
pl.ds(0, 1),
|
|
928
|
+
pl.ds(bd2c_id * bd2c_per_t_packing,
|
|
929
|
+
bd2c_per_t_packing),
|
|
930
|
+
)
|
|
931
|
+
w2_scale = jnp.broadcast_to(
|
|
932
|
+
w2_scale_vmem[*w2_scale_slices], acc.shape)
|
|
933
|
+
acc *= w2_scale
|
|
934
|
+
res += acc
|
|
935
|
+
res = pltpu.bitcast(res, jnp.uint32)
|
|
936
|
+
if t_packing == 2:
|
|
937
|
+
res = res >> 16 << (16 * p_id)
|
|
938
|
+
else:
|
|
939
|
+
assert t_packing == 1
|
|
940
|
+
res_lst.append(res)
|
|
941
|
+
res = res_lst[0]
|
|
942
|
+
# TODO(jevinjiang): use interleaved packing when it is exposed to Pallas
|
|
943
|
+
for i in range(1, t_packing):
|
|
944
|
+
res |= res_lst[i]
|
|
945
|
+
sliced_res_vmem = res_b32_vmem.at[
|
|
946
|
+
pl.ds(btc_id * btc, btc),
|
|
947
|
+
pl.ds(bd2c_id * bd2c_per_t_packing, bd2c_per_t_packing),
|
|
948
|
+
]
|
|
949
|
+
if should_init:
|
|
950
|
+
sliced_res_vmem[...] = res
|
|
951
|
+
else:
|
|
952
|
+
sliced_res_vmem[...] = pltpu.bitcast(
|
|
953
|
+
sliced_res_vmem.bitcast(t_dtype)[...] +
|
|
954
|
+
pltpu.bitcast(res, t_dtype),
|
|
955
|
+
sliced_res_vmem.dtype,
|
|
956
|
+
)
|
|
957
|
+
|
|
958
|
+
lax.fori_loop(0, num_loops, body, None)
|
|
959
|
+
|
|
960
|
+
def expert_ffn(bt_id, e_sem_id, local_e_id):
|
|
961
|
+
bt_sem_id = bt_id % 2
|
|
962
|
+
bw_sem_id = 0
|
|
963
|
+
# start_fetch_bw1(local_e_id, bw_sem_id, 0, 0)
|
|
964
|
+
# start_fetch_bw3(local_e_id, bw_sem_id, 0, 0)
|
|
965
|
+
a2a_s_b32_vmem = (a2a_s_x2_vmem.bitcast(jnp.uint32).reshape(
|
|
966
|
+
2, bt * num_devices, hidden_size // t_packing).at[e_sem_id])
|
|
967
|
+
a2a_s_acc_b32_vmem = (a2a_s_acc_x2_vmem.bitcast(jnp.uint32).reshape(
|
|
968
|
+
2, bt * num_devices, hidden_size // t_packing).at[e_sem_id])
|
|
969
|
+
b_acc_vmem_2d = b_acc_vmem.reshape(bt * num_devices, bf * 2)
|
|
970
|
+
b_acc1_vmem = b_acc_vmem_2d.at[:, :bf]
|
|
971
|
+
b_acc3_vmem = b_acc_vmem_2d.at[:, bf:]
|
|
972
|
+
|
|
973
|
+
e_id = my_id * local_num_experts + local_e_id
|
|
974
|
+
dyn_sz = expert_sizes_x2_smem[bt_sem_id, 0, e_id]
|
|
975
|
+
|
|
976
|
+
bd1_per_t_packing = bd1 // t_packing
|
|
977
|
+
bd2_per_t_packing = bd2 // t_packing
|
|
978
|
+
|
|
979
|
+
for bf_id in range(num_bf):
|
|
980
|
+
for bd1_id in range(num_bd1):
|
|
981
|
+
start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, 0)
|
|
982
|
+
w1_scale_vmem = (None if b_w1_scale_x2_vmem is None else
|
|
983
|
+
b_w1_scale_x2_vmem.at[bw_sem_id])
|
|
984
|
+
w3_scale_vmem = (None if b_w3_scale_x2_vmem is None else
|
|
985
|
+
b_w3_scale_x2_vmem.at[bw_sem_id])
|
|
986
|
+
b1_vmem = None if b_b1_x2_vmem is None else b_b1_x2_vmem.at[
|
|
987
|
+
bf_id % 2]
|
|
988
|
+
b3_vmem = None if b_b3_x2_vmem is None else b_b3_x2_vmem.at[
|
|
989
|
+
bf_id % 2]
|
|
990
|
+
wait_fetch_bw1(local_e_id, bw_sem_id, bf_id, bd1_id)
|
|
991
|
+
wait_fetch_bw3(local_e_id, bw_sem_id, bf_id, bd1_id)
|
|
992
|
+
|
|
993
|
+
dynamic_ffn1(
|
|
994
|
+
t_b32_vmem=a2a_s_b32_vmem.at[
|
|
995
|
+
...,
|
|
996
|
+
pl.ds(bd1_id * bd1_per_t_packing, bd1_per_t_packing)],
|
|
997
|
+
w1_vmem=b_w1_x2_vmem.at[bw_sem_id],
|
|
998
|
+
w1_scale_vmem=w1_scale_vmem,
|
|
999
|
+
b1_vmem=b1_vmem,
|
|
1000
|
+
w3_vmem=b_w3_x2_vmem.at[bw_sem_id],
|
|
1001
|
+
w3_scale_vmem=w3_scale_vmem,
|
|
1002
|
+
b3_vmem=b3_vmem,
|
|
1003
|
+
acc1_vmem=b_acc1_vmem,
|
|
1004
|
+
acc3_vmem=b_acc3_vmem,
|
|
1005
|
+
dyn_sz=dyn_sz,
|
|
1006
|
+
should_init=(bd1_id == 0),
|
|
1007
|
+
)
|
|
1008
|
+
bw_sem_id = (bw_sem_id + 1) % 2
|
|
1009
|
+
|
|
1010
|
+
for bd2_id in range(num_bd2):
|
|
1011
|
+
start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, num_bd1,
|
|
1012
|
+
bd2_id)
|
|
1013
|
+
wait_fetch_bw2(local_e_id, bw_sem_id, bf_id, bd2_id)
|
|
1014
|
+
if bf_id == bd2_id == 0:
|
|
1015
|
+
wait_a2a_gather_send(bt_id, e_sem_id, local_e_id - 2)
|
|
1016
|
+
|
|
1017
|
+
w2_scale_vmem = (None if b_w2_scale_x2_vmem is None else
|
|
1018
|
+
b_w2_scale_x2_vmem.at[bw_sem_id])
|
|
1019
|
+
b2_vmem = None if b_b2_x2_vmem is None else b_b2_x2_vmem.at[
|
|
1020
|
+
bd2_id % 2]
|
|
1021
|
+
dynamic_ffn2(
|
|
1022
|
+
acc1_vmem=b_acc1_vmem,
|
|
1023
|
+
acc3_vmem=b_acc3_vmem,
|
|
1024
|
+
w2_vmem=b_w2_x2_vmem.at[bw_sem_id],
|
|
1025
|
+
w2_scale_vmem=w2_scale_vmem,
|
|
1026
|
+
b2_vmem=b2_vmem,
|
|
1027
|
+
res_b32_vmem=a2a_s_acc_b32_vmem.at[
|
|
1028
|
+
...,
|
|
1029
|
+
pl.ds(bd2_id * bd2_per_t_packing, bd2_per_t_packing)],
|
|
1030
|
+
dyn_sz=dyn_sz,
|
|
1031
|
+
should_init=(bf_id == 0),
|
|
1032
|
+
)
|
|
1033
|
+
bw_sem_id = (bw_sem_id + 1) % 2
|
|
1034
|
+
|
|
1035
|
+
def bt_acc(bt_id, top_k_logits_lst):
|
|
1036
|
+
bt_sem_id = bt_id % 2
|
|
1037
|
+
for bt_t_id in range(bt):
|
|
1038
|
+
for k_id in range(top_k):
|
|
1039
|
+
e_id = t2e_routing_x2_smem[bt_sem_id, bt_t_id, k_id]
|
|
1040
|
+
offset = expert_offsets_x2_smem[bt_sem_id, 1, e_id]
|
|
1041
|
+
expert_offsets_x2_smem[bt_sem_id, 1, e_id] = offset + 1
|
|
1042
|
+
pltpu.make_async_copy(
|
|
1043
|
+
src_ref=a2a_g_hbm.at[e_id, pl.ds(offset, 1)],
|
|
1044
|
+
dst_ref=a2a_g_acc_vmem.at[k_id, pl.ds(bt_t_id, 1)],
|
|
1045
|
+
sem=a2a_acc_sem,
|
|
1046
|
+
).start()
|
|
1047
|
+
pltpu.make_async_copy(
|
|
1048
|
+
src_ref=a2a_g_acc_vmem,
|
|
1049
|
+
dst_ref=a2a_g_acc_vmem,
|
|
1050
|
+
sem=a2a_acc_sem,
|
|
1051
|
+
).wait()
|
|
1052
|
+
output = None
|
|
1053
|
+
for k_id in range(top_k):
|
|
1054
|
+
acc = a2a_g_acc_vmem[k_id].reshape(bt, hidden_size)
|
|
1055
|
+
logits = broadcast_minor(top_k_logits_lst[k_id], acc.shape)
|
|
1056
|
+
acc *= logits
|
|
1057
|
+
if output is None:
|
|
1058
|
+
output = acc
|
|
1059
|
+
else:
|
|
1060
|
+
output += acc
|
|
1061
|
+
assert output is not None
|
|
1062
|
+
return output.astype(output_hbm.dtype)
|
|
1063
|
+
|
|
1064
|
+
def start_send_bo(bt_id, priority=0):
|
|
1065
|
+
bt_sem_id = bt_id % 2
|
|
1066
|
+
b_output_sem = local_sems.at[bt_sem_id, 4]
|
|
1067
|
+
pltpu.make_async_copy(
|
|
1068
|
+
src_ref=b_output_x2_vmem.at[bt_sem_id],
|
|
1069
|
+
dst_ref=output_hbm.at[pl.ds(bt_id * bt, bt)],
|
|
1070
|
+
sem=b_output_sem,
|
|
1071
|
+
).start(priority=priority)
|
|
1072
|
+
|
|
1073
|
+
def wait_send_bo(bt_id):
|
|
1074
|
+
is_valid = jnp.logical_and(0 <= bt_id, bt_id < num_bt)
|
|
1075
|
+
sz = pl.multiple_of(lax.select(is_valid, bt, 0), bt)
|
|
1076
|
+
bt_sem_id = (bt_id + 2) % 2
|
|
1077
|
+
b_output_sem = local_sems.at[bt_sem_id, 4]
|
|
1078
|
+
pltpu.make_async_copy(
|
|
1079
|
+
src_ref=output_hbm.at[pl.ds(0, sz)],
|
|
1080
|
+
dst_ref=output_hbm.at[pl.ds(0, sz)],
|
|
1081
|
+
sem=b_output_sem,
|
|
1082
|
+
).wait()
|
|
1083
|
+
|
|
1084
|
+
### ------- Kernel start ------- ###
|
|
1085
|
+
start_fetch_b_gating(bt_id=0)
|
|
1086
|
+
|
|
1087
|
+
def run_per_bt(bt_id, e_sem_id):
|
|
1088
|
+
bt_sem_id = bt_id % 2
|
|
1089
|
+
next_bt_id = bt_id + 1
|
|
1090
|
+
start_fetch_b_gating(next_bt_id)
|
|
1091
|
+
wait_fetch_b_gating(bt_id)
|
|
1092
|
+
|
|
1093
|
+
b_gating = b_gating_x2_vmem[bt_sem_id]
|
|
1094
|
+
b_gating_score = jax.nn.softmax(b_gating, axis=-1)
|
|
1095
|
+
top_k_logits_lst, t2e_routing, expert_sizes, expert_starts = get_top_k(
|
|
1096
|
+
b_gating_score, top_k, renormalize_topk_logits)
|
|
1097
|
+
|
|
1098
|
+
all_reduce_metadata(bt_sem_id, t2e_routing, expert_starts,
|
|
1099
|
+
expert_sizes)
|
|
1100
|
+
sync_barrier()
|
|
1101
|
+
|
|
1102
|
+
# Start a2a scatter for first active expert.
|
|
1103
|
+
start_a2a_scatter(bt_id=bt_id, e_sem_id=e_sem_id, local_e_id=0)
|
|
1104
|
+
|
|
1105
|
+
def run_per_expert(local_e_id, e_sem_id):
|
|
1106
|
+
sync_barrier()
|
|
1107
|
+
|
|
1108
|
+
# Prefetch weights for CURRENT active expert.
|
|
1109
|
+
# TODO(jevinjiang): It is hard to prefetch weights in previous iteration
|
|
1110
|
+
# because the expert_ffn keeps overwriting the buffers. Triple buffering
|
|
1111
|
+
# could resolve this but it takes more VMEM scratch. Need further
|
|
1112
|
+
# experiment on this.
|
|
1113
|
+
start_fetch_bw1(local_e_id, bw1_sem_id=0, bf_id=0, bd1_id=0)
|
|
1114
|
+
start_fetch_bw3(local_e_id, bw3_sem_id=0, bf_id=0, bd3_id=0)
|
|
1115
|
+
|
|
1116
|
+
# Next ids.
|
|
1117
|
+
next_e_sem_id = lax.select(e_sem_id == 0, 1, 0)
|
|
1118
|
+
next_local_e_id = local_e_id + 1
|
|
1119
|
+
|
|
1120
|
+
# Start a2a scatter for NEXT active expert.
|
|
1121
|
+
@pl.when(next_local_e_id < local_num_experts)
|
|
1122
|
+
def _():
|
|
1123
|
+
start_a2a_scatter(bt_id, next_e_sem_id, next_local_e_id)
|
|
1124
|
+
|
|
1125
|
+
# Wait a2a scatter for CURRENT active expert.
|
|
1126
|
+
wait_a2a_scatter_recv(bt_id, e_sem_id, local_e_id)
|
|
1127
|
+
|
|
1128
|
+
# Perform FFN for CURRENT active expert.
|
|
1129
|
+
expert_ffn(bt_id, e_sem_id, local_e_id)
|
|
1130
|
+
|
|
1131
|
+
# Start a2a gather to send back tokens for CURRENT active expert.
|
|
1132
|
+
start_a2a_gather(bt_id, e_sem_id, local_e_id)
|
|
1133
|
+
|
|
1134
|
+
# A must-wait before next sync_barrier.
|
|
1135
|
+
wait_a2a_scatter_send(bt_id, e_sem_id, local_e_id)
|
|
1136
|
+
return next_e_sem_id
|
|
1137
|
+
|
|
1138
|
+
e_sem_id = lax.fori_loop(0,
|
|
1139
|
+
local_num_experts,
|
|
1140
|
+
run_per_expert,
|
|
1141
|
+
e_sem_id,
|
|
1142
|
+
unroll=False)
|
|
1143
|
+
|
|
1144
|
+
# Wait to receive a2a gather for ALL experts.
|
|
1145
|
+
wait_a2a_gather_recv_all()
|
|
1146
|
+
|
|
1147
|
+
# Accumulate results for current batch.
|
|
1148
|
+
output = bt_acc(bt_id, top_k_logits_lst)
|
|
1149
|
+
|
|
1150
|
+
# Make sure it is safe to overwrite output buffer.
|
|
1151
|
+
wait_send_bo(bt_id=bt_id - 2)
|
|
1152
|
+
b_output_x2_vmem[bt_sem_id] = output
|
|
1153
|
+
|
|
1154
|
+
start_send_bo(bt_id)
|
|
1155
|
+
|
|
1156
|
+
wait_a2a_gather_send(
|
|
1157
|
+
bt_id,
|
|
1158
|
+
e_sem_id=e_sem_id,
|
|
1159
|
+
local_e_id=local_num_experts - 2,
|
|
1160
|
+
)
|
|
1161
|
+
wait_a2a_gather_send(
|
|
1162
|
+
bt_id,
|
|
1163
|
+
e_sem_id=lax.select(e_sem_id == 0, 1, 0),
|
|
1164
|
+
local_e_id=local_num_experts - 1,
|
|
1165
|
+
)
|
|
1166
|
+
return e_sem_id
|
|
1167
|
+
|
|
1168
|
+
lax.fori_loop(0, num_bt, run_per_bt, 0, unroll=False)
|
|
1169
|
+
wait_send_bo(bt_id=num_bt - 2)
|
|
1170
|
+
wait_send_bo(bt_id=num_bt - 1)
|
|
1171
|
+
|
|
1172
|
+
### ------- Kernel end ------- ###
|
|
1173
|
+
|
|
1174
|
+
|
|
1175
|
+
@functools.partial(
    jax.jit,
    static_argnames=[
        "mesh",
        "top_k",
        "renormalize_topk_logits",
        "act_fn",
        "subc_quant_wsz",
        "bt",
        "bf",
        "bd1",
        "bd2",
        "btc",
        "bfc",
        "bd1c",
        "bd2c",
        "ep_axis_name",
    ],
)
def fused_ep_moe(
    mesh: jax.sharding.Mesh,
    tokens: jax.Array,  # (num_tokens, hidden_size)
    w1: jax.Array,  # (num_experts, 2, hidden_size, intermediate_size)
    w2: jax.Array,  # (num_experts, intermediate_size, hidden_size)
    gating_output: jax.Array,  # (num_tokens, num_experts)
    top_k: int,
    *,
    renormalize_topk_logits: bool = False,
    act_fn: str = "silu",
    subc_quant_wsz: int | None = None,
    w1_scale: (
        jax.Array | None
    ) = None,  # F32(num_experts, 2, hidden_size // subc_quant_wsz, 1, intermediate_size)
    w2_scale: (
        jax.Array | None
    ) = None,  # F32(num_experts, intermediate_size // subc_quant_wsz, 1, hidden_size)
    b1: jax.Array | None = None,  # F32(num_experts, 2, 1, intermediate_size)
    b2: jax.Array | None = None,  # F32(num_experts, 1, hidden_size)
    # Kernel tuning parameters.
    bt: int,
    bf: int,
    bd1: int,
    bd2: int,
    btc: int,
    bfc: int,
    bd1c: int,
    bd2c: int,
    ep_axis_name: str = "model",
):
    """Run the fused expert-parallel MoE Pallas kernel over `mesh`.

    The inputs are validated and padded. Tokens, weights, optional
    scales/biases and gating logits are then sharded along `ep_axis_name`
    with `jax.shard_map`. Each device invokes `_fused_ep_moe_kernel` on its
    shard through `pl.pallas_call`, with every operand placed in HBM.

    Args:
      mesh: 2D device mesh. Every axis other than `ep_axis_name` must have
        size 1.
      tokens: (num_tokens, hidden_size). `num_tokens` must be divisible by
        the EP size.
      w1: (num_experts, 2, hidden_size, intermediate_size). The kernel
        stages the two stacked slices into separate VMEM buffers (named
        w1 / w3 in the scratch layout).
      w2: (num_experts, intermediate_size, hidden_size).
      gating_output: (num_tokens, num_experts) routing logits. They are
        padded with -inf up to a multiple of 128 experts before the kernel
        runs.
      top_k: Number of experts selected per token; 0 < top_k <= num_experts.
      renormalize_topk_logits: Forwarded to the kernel (presumably
        renormalizes the selected top-k weights — see kernel).
      act_fn: Activation name forwarded to the kernel.
      subc_quant_wsz: Sub-channel quantization window size. Required when
        `w1_scale` or `w2_scale` is given. Must be a positive multiple of
        256 that divides `hidden_size` and `intermediate_size`. When set,
        it overrides `bd1c` (to `subc_quant_wsz * t_packing`) and `bfc`
        (to `subc_quant_wsz`).
      w1_scale: Optional scale for `w1` with the shape in the signature
        comment. Cast to float32.
      w2_scale: Optional scale for `w2` with the shape in the signature
        comment. Cast to float32.
      b1: Optional bias of shape (num_experts, 2, 1, intermediate_size).
        Cast to float32.
      b2: Optional bias of shape (num_experts, 1, hidden_size). Cast to
        float32.
      bt, bf, bd1, bd2, btc, bfc, bd1c, bd2c: Kernel tiling sizes. `bt` and
        `btc` are clamped to the per-device token count. All sizes are
        validated for alignment below.
      ep_axis_name: Mesh axis used for expert parallelism.

    Returns:
      An array of shape (num_tokens, hidden_size) and dtype `tokens.dtype`,
      sharded along `ep_axis_name`.

    Raises:
      NotImplementedError: If the mesh is not 2D, or a non-EP axis has size
        other than 1.
      ValueError: If `ep_axis_name` is not a mesh axis, or any shape,
        alignment or tiling requirement is violated.
    """
    # TODO(jevinjiang): move all these assertions to validation function.
    if len(mesh.shape) != 2:
        raise NotImplementedError("Only 2D mesh is supported.")

    # Fail with a clear message instead of a bare KeyError from mesh.shape.
    if ep_axis_name not in mesh.axis_names:
        raise ValueError(
            f"Expected {ep_axis_name=} to be one of {mesh.axis_names=}.")

    for axis_name in mesh.axis_names:
        if axis_name == ep_axis_name:
            continue
        if mesh.shape[axis_name] != 1:
            raise NotImplementedError(
                f"Expected all non-ep axis to have size 1 in {mesh.shape=}")

    ep_size = mesh.shape[ep_axis_name]
    num_devices = ep_size

    num_tokens, hidden_size = tokens.shape
    num_experts, intermediate_size, _ = w2.shape

    if w1.shape != (num_experts, 2, hidden_size, intermediate_size):
        raise ValueError(
            f"Expected {w1.shape=} to be"
            f" {(num_experts, 2, hidden_size, intermediate_size)}.")

    if w2.shape != (num_experts, intermediate_size, hidden_size):
        raise ValueError(f"Expected {w2.shape=} to be"
                         f" {(num_experts, intermediate_size, hidden_size)}.")

    if gating_output.shape != (num_tokens, num_experts):
        raise ValueError(
            f"Expected {gating_output.shape=} to be {(num_tokens, num_experts)}."
        )

    if not (0 < top_k <= num_experts):
        raise ValueError(
            f"Expected {top_k=} to be in range (0, {num_experts=}].")

    if hidden_size % 128 != 0 or intermediate_size % 128 != 0:
        raise ValueError(
            f"Expected {hidden_size=} and {intermediate_size=} to be aligned to"
            " 128. Did you pad them with zeros outside the kernel?")
    if num_tokens % ep_size != 0:
        raise ValueError(
            f"Expected {num_tokens=} to be aligned to {ep_size=}.")
    if num_experts % ep_size != 0:
        raise ValueError(
            f"Expected {num_experts=} to be aligned to {ep_size=}.")

    local_num_tokens = num_tokens // ep_size
    padded_num_experts = align_to(num_experts, 128)
    padded_top_k = align_to(top_k, 128)
    t_dtype = tokens.dtype
    t_packing = get_dtype_packing(t_dtype)

    # Override bt: for tiny per-device token counts, process everything in a
    # single block.
    if local_num_tokens <= t_packing * 8:
        bt = local_num_tokens
        btc = bt
    bt = min(local_num_tokens, bt)
    # The worst case is that all devices send bt to one device.
    btc = min(bt, btc, bt * num_devices)

    if local_num_tokens % t_packing != 0:
        raise ValueError(
            f"Expected {local_num_tokens=} to be aligned to {t_packing=}.")

    if bt % t_packing != 0:
        raise ValueError(f"Expected {bt=} to be aligned to {t_packing=}.")
    if local_num_tokens % bt != 0:
        raise ValueError(
            f"Expected {local_num_tokens=} to be aligned to {bt=}.")

    # The scale shapes below are expressed in units of subc_quant_wsz, so
    # quantized weights cannot be handled without it.
    if (w1_scale is not None or w2_scale is not None) and subc_quant_wsz is None:
        raise ValueError(
            "Expected subc_quant_wsz to be set when w1_scale or w2_scale is"
            " provided.")

    if subc_quant_wsz is not None:
        if subc_quant_wsz <= 0:
            raise ValueError(f"Expected {subc_quant_wsz=} to be positive.")
        if subc_quant_wsz % 256 != 0:
            raise ValueError(
                f"Expected {subc_quant_wsz=} to be aligned to 256.")
        if hidden_size % subc_quant_wsz != 0:
            raise ValueError(
                f"Expected {hidden_size=} to be aligned to {subc_quant_wsz=}.")
        if intermediate_size % subc_quant_wsz != 0:
            raise ValueError(
                f"Expected {intermediate_size=} to be aligned to {subc_quant_wsz=}."
            )
        # We force compute size of contracting dim to be subc_quant_wsz. So we can
        # apply same scale after matmul and accumulation.
        bd1c = subc_quant_wsz * t_packing
        bfc = subc_quant_wsz

    if bfc % 128 != 0:
        raise ValueError(f"Expected {bfc=} to be aligned to 128.")
    if bd1c % (t_packing * 128) != 0:
        raise ValueError(
            f"Expected {bd1c=} to be aligned to {t_packing * 128}.")
    if bd2c % (t_packing * 128) != 0:
        raise ValueError(
            f"Expected {bd2c=} to be aligned to {t_packing * 128}.")
    if bf % bfc != 0:
        raise ValueError(f"Expected {bf=} to be aligned to {bfc=}.")
    if bd1 % bd1c != 0:
        raise ValueError(f"Expected {bd1=} to be aligned to {bd1c=}.")
    if bd2 % bd2c != 0:
        raise ValueError(f"Expected {bd2=} to be aligned to {bd2c=}.")
    if hidden_size % bd1 != 0 or hidden_size % bd2 != 0:
        raise ValueError(
            f"Expected {hidden_size=} to be aligned to {bd1=} and {bd2=}.")
    if intermediate_size % bf != 0:
        raise ValueError(
            f"Expected {intermediate_size=} to be aligned to {bf=}.")

    # Note: we should dump scale as the kernel expected shape in the
    # checkpoint offline or reshape right after weight loading.
    if w1_scale is not None:
        expected_w1_scale_shape = (
            num_experts,
            2,
            hidden_size // subc_quant_wsz,
            1,
            intermediate_size,
        )
        if w1_scale.shape != expected_w1_scale_shape:
            raise ValueError(
                f"Expected {w1_scale.shape=} to be {expected_w1_scale_shape}.")
        if w1_scale.dtype != jnp.float32:
            w1_scale = w1_scale.astype(jnp.float32)

    if w2_scale is not None:
        expected_w2_scale_shape = (
            num_experts,
            intermediate_size // subc_quant_wsz,
            1,
            hidden_size,
        )
        if w2_scale.shape != expected_w2_scale_shape:
            raise ValueError(
                f"Expected {w2_scale.shape=} to be {expected_w2_scale_shape}.")
        if w2_scale.dtype != jnp.float32:
            w2_scale = w2_scale.astype(jnp.float32)

    if b1 is not None:
        expected_b1_shape = (num_experts, 2, 1, intermediate_size)
        if b1.shape != expected_b1_shape:
            raise ValueError(
                f"Expected {b1.shape=} to be {expected_b1_shape}.")
        if b1.dtype != jnp.float32:
            b1 = b1.astype(jnp.float32)

    if b2 is not None:
        expected_b2_shape = (num_experts, 1, hidden_size)
        if b2.shape != expected_b2_shape:
            raise ValueError(
                f"Expected {b2.shape=} to be {expected_b2_shape}.")
        if b2.dtype != jnp.float32:
            b2 = b2.astype(jnp.float32)

    # Prepare inputs for the kernel.
    # Padded expert columns get -inf logits so they can never be selected.
    if padded_num_experts != gating_output.shape[-1]:
        gating_output = jnp.pad(
            gating_output,
            ((0, 0), (0, padded_num_experts - gating_output.shape[-1])),
            constant_values=-jnp.inf,
        )

    # Split hidden dim into (t_packing, hidden_size // t_packing) to match the
    # packed layout of the VMEM token buffers below.
    tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing)

    hbm_block_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM)
    renorm_str = "-renorm_k" if renormalize_topk_logits else ""
    scope_name = f"fused-moe-k_{top_k}{renorm_str}-bt_{bt}_{btc}-bf_{bf}_{bfc}-bd1_{bd1}_{bd1c}-bd2_{bd2}_{bd2c}"
    fused_moe = pl.pallas_call(
        functools.partial(
            _fused_ep_moe_kernel,
            top_k=top_k,
            renormalize_topk_logits=renormalize_topk_logits,
            ep_axis_name=ep_axis_name,
            act_fn=act_fn,
            subc_quant_wsz=subc_quant_wsz,
            bt=bt,
            bf=bf,
            bd1=bd1,
            bd2=bd2,
            btc=btc,
            bfc=bfc,
            bd1c=bd1c,
            bd2c=bd2c,
        ),
        out_shape=jax.ShapeDtypeStruct((local_num_tokens, hidden_size),
                                       t_dtype),
        grid_spec=pltpu.PrefetchScalarGridSpec(
            num_scalar_prefetch=0,
            in_specs=[
                hbm_block_spec,  # tokens_hbm
                hbm_block_spec,  # w1_hbm
                hbm_block_spec,  # w2_hbm
                None if w1_scale is None else hbm_block_spec,  # w1_scale_hbm
                None if w2_scale is None else hbm_block_spec,  # w2_scale_hbm
                None if b1 is None else hbm_block_spec,  # b1_hbm
                None if b2 is None else hbm_block_spec,  # b2_hbm
                hbm_block_spec,  # gating_output_hbm
                hbm_block_spec,  # a2a_g_hbm
            ],
            out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
            scratch_shapes=([
                # t2e_routing_x2_smem
                pltpu.SMEM((2, bt, padded_top_k), jnp.int32),
                # d2e_count_x2_smem
                pltpu.SMEM((2, num_devices, 1, padded_num_experts), jnp.int32),
                # expert_offsets_x2_smem
                pltpu.SMEM((2, 2, padded_num_experts), jnp.int32),
                # expert_starts_x2_smem
                pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
                # expert_sizes_x2_smem
                pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
                # a2a_s_sends_x2_smem
                pltpu.SMEM((2, ), jnp.int32),
                # a2a_s_x2_vmem
                pltpu.VMEM(
                    (
                        2,
                        bt * num_devices,
                        t_packing,
                        hidden_size // t_packing,
                    ),
                    t_dtype,
                ),
                # a2a_s_acc_x2_vmem
                pltpu.VMEM(
                    (
                        2,
                        bt * num_devices,
                        t_packing,
                        hidden_size // t_packing,
                    ),
                    t_dtype,
                ),
                # a2a_g_acc_vmem
                pltpu.VMEM((top_k, bt, t_packing, hidden_size // t_packing),
                           t_dtype),
                # b_gating_x2_vmem
                pltpu.VMEM((2, bt, padded_num_experts), t_dtype),
                # b_output_x2_vmem
                pltpu.VMEM((2, bt, hidden_size), t_dtype),
                # b_w1_x2_vmem
                pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
                # b_w3_x2_vmem
                pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
                # b_w2_x2_vmem
                pltpu.VMEM((2, t_packing, bf, bd2 // t_packing), w2.dtype),
                # b_w1_scale_x2_vmem
                (None if w1_scale is None else pltpu.VMEM(
                    (
                        2,
                        t_packing,
                        bd1 // t_packing // subc_quant_wsz,
                        1,
                        bf,
                    ),
                    jnp.float32,
                )),
                # b_w3_scale_x2_vmem
                (None if w1_scale is None else pltpu.VMEM(
                    (
                        2,
                        t_packing,
                        bd1 // t_packing // subc_quant_wsz,
                        1,
                        bf,
                    ),
                    jnp.float32,
                )),
                # b_w2_scale_x2_vmem
                (None if w2_scale is None else pltpu.VMEM(
                    (
                        2,
                        t_packing,
                        bf // subc_quant_wsz,
                        1,
                        bd2 // t_packing,
                    ),
                    jnp.float32,
                )),
                # b_b1_x2_vmem
                (None if b1 is None else pltpu.VMEM(
                    (
                        2,
                        1,
                        bf,
                    ),
                    jnp.float32,
                )),
                # b_b3_x2_vmem
                (None if b1 is None else pltpu.VMEM(
                    (
                        2,
                        1,
                        bf,
                    ),
                    jnp.float32,
                )),
                # b_b2_x2_vmem
                (None if b2 is None else pltpu.VMEM(
                    (
                        2,
                        t_packing,
                        1,
                        bd2 // t_packing,
                    ),
                    jnp.float32,
                )),
                # b_acc_vmem
                pltpu.VMEM((bt * num_devices, 1, bf * 2), jnp.float32),
                # local_sems
                pltpu.SemaphoreType.DMA((2, 5)),
                # send_sems
                pltpu.SemaphoreType.DMA((2, )),
                # recv_sems
                pltpu.SemaphoreType.DMA((2, )),
                # a2a_gather_sem
                pltpu.SemaphoreType.DMA,
                # a2a_acc_sem
                pltpu.SemaphoreType.DMA,
            ]),
        ),
        compiler_params=pltpu.CompilerParams(
            collective_id=0,
            vmem_limit_bytes=100 * 1024 * 1024,
        ),
        name=scope_name,
    )

    @jax.jit
    @jax.shard_map(
        mesh=mesh,
        in_specs=(
            P(ep_axis_name),  # tokens_hbm
            P(ep_axis_name),  # w1_hbm
            P(ep_axis_name),  # w2_hbm
            None if w1_scale is None else P(ep_axis_name),  # w1_scale_hbm
            None if w2_scale is None else P(ep_axis_name),  # w2_scale_hbm
            None if b1 is None else P(ep_axis_name),  # b1_hbm
            None if b2 is None else P(ep_axis_name),  # b2_hbm
            P(ep_axis_name),  # gating_output_hbm
            P(),  # a2a_g_hbm
        ),
        out_specs=P(ep_axis_name),
        check_vma=False,
    )
    def kernel(
        tokens,
        w1,
        w2,
        w1_scale,
        w2_scale,
        b1,
        b2,
        gating_output,
        a2a_g_hbm_scratch,
    ):
        # Per-device body: pin every operand to HBM and invoke the kernel.
        return fused_moe(
            pltpu.with_memory_space_constraint(tokens,
                                               pltpu.HBM),  # tokens_hbm
            pltpu.with_memory_space_constraint(w1, pltpu.HBM),  # w1_hbm
            pltpu.with_memory_space_constraint(w2, pltpu.HBM),  # w2_hbm
            (None if w1_scale is None else pltpu.with_memory_space_constraint(
                w1_scale, pltpu.HBM)),  # w1_scale_hbm
            (None if w2_scale is None else pltpu.with_memory_space_constraint(
                w2_scale, pltpu.HBM)),  # w2_scale_hbm
            (None if b1 is None else pltpu.with_memory_space_constraint(
                b1, pltpu.HBM)),  # b1_hbm
            (None if b2 is None else pltpu.with_memory_space_constraint(
                b2, pltpu.HBM)),  # b2_hbm
            pltpu.with_memory_space_constraint(gating_output,
                                               pltpu.HBM),  # gating_output_hbm
            pltpu.with_memory_space_constraint(a2a_g_hbm_scratch,
                                               pltpu.HBM),  # a2a_g_hbm
        )

    # Uninitialized HBM scratch passed to the kernel as a2a_g_hbm (replicated
    # via P()); its contents are presumably written by the kernel before use.
    a2a_g_hbm_scratch = pl.empty(
        (num_experts, bt, t_packing, hidden_size // t_packing), t_dtype)
    return kernel(
        tokens,
        w1,
        w2,
        w1_scale,
        w2_scale,
        b1,
        b2,
        gating_output,
        a2a_g_hbm_scratch,
    )
|