tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +45 -15
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +41 -16
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
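The diff below covers tpu_inference/kernels/fused_moe/v1/kernel.py (+370 -324), the fused expert-parallel MoE Pallas kernel.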
@@ -1,3 +1,16 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """TPU-Friendly Fused Mixture of Experts (MoE) kernel."""

 import functools
@@ -19,7 +32,8 @@ def align_to(x, a):


 def get_dtype_packing(dtype):
-    bits = dtypes.bit_width(dtype)
+    bits = (dtypes.bit_width(dtype)
+            if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
     return 32 // bits


@@ -65,18 +79,19 @@ def ref_moe(
     top_k: int,
     *,
     renormalize_topk_logits: bool = False,
-
+    act_fn: str = "silu",
     subc_quant_wsz: int | None = None,
     w1_scale:
     (
         jax.Array | None
-    ) = None,  # (num_experts, 2,
+    ) = None,  # F32(num_experts, 2, hidden_size //subc_quant_wsz, 1, intermediate_size)
     w2_scale:
     (
         jax.Array | None
-    ) = None,  # (num_experts,
-    b1: jax.Array
-
+    ) = None,  # F32(num_experts, intermediate_size // subc_quant_wsz, 1, hidden_size)
+    b1: jax.Array
+    | None = None,  # F32(num_experts, 2, 1, intermediate_size)
+    b2: jax.Array | None = None,  # F32(num_experts, 1, hidden_size)
 ):
     n_tokens = tokens.shape[0]  # num_tokens

@@ -97,7 +112,7 @@ def ref_moe(

     # Process each token individually
     for i in range(n_tokens):
-        curr_token = jnp.expand_dims(tokens[i], axis=0)  # [1,
+        curr_token = jnp.expand_dims(tokens[i], axis=0)  # [1, hidden_size]
         assigned_expert_ids = top_k_indices[
             i]  # [top_k] - indices of selected experts for token i
         tok_expert_act = []
@@ -108,19 +123,19 @@ def ref_moe(
             expert_w1 = w1[expert_id, 0].astype(jnp.float32)
             expert_w3 = w1[expert_id, 1].astype(jnp.float32)
             if w1_scale is not None:
-                expert_w1 *= jnp.repeat(w1_scale[expert_id, 0],
+                expert_w1 *= jnp.repeat(w1_scale[expert_id, 0, :, 0],
                                         subc_quant_wsz,
                                         axis=0)[:hidden_size]
-                expert_w3 *= jnp.repeat(w1_scale[expert_id, 1],
+                expert_w3 *= jnp.repeat(w1_scale[expert_id, 1, :, 0],
                                         subc_quant_wsz,
                                         axis=0)[:hidden_size]
             expert_weight_1 = jnp.concat(
                 [expert_w1, expert_w3],
-                axis=-1)  # [
+                axis=-1)  # [hidden_size, 2 * intermediate_size]
             expert_weight_2 = w2[expert_id].astype(
-                jnp.float32)  # [intermediate_size,
+                jnp.float32)  # [intermediate_size, hidden_size]
             if w2_scale is not None:
-                expert_weight_2 *= jnp.repeat(w2_scale[expert_id],
+                expert_weight_2 *= jnp.repeat(w2_scale[expert_id, :, 0],
                                               subc_quant_wsz,
                                               axis=0)[:intermediate_size]

@@ -132,32 +147,33 @@ def ref_moe(
                 gmm_1_out, 2,
                 axis=-1)  # [1, intermediate_size], [1, intermediate_size]
             if b1 is not None:
-                gmm1_w1_proj += b1[expert_id:expert_id + 1, 0]
-                gmm1_w3_proj += b1[expert_id:expert_id + 1, 1]
+                gmm1_w1_proj += b1[expert_id:expert_id + 1, 0, 0]
+                gmm1_w3_proj += b1[expert_id:expert_id + 1, 1, 0]

             # Apply gated activation: activation(gate) * up
-            act = activation_fn(gmm1_w1_proj, gmm1_w3_proj,
+            act = activation_fn(gmm1_w1_proj, gmm1_w3_proj, act_fn)

             # Second linear layer (down projection)
-            gmm_2_out = act @ expert_weight_2  # [1,
+            gmm_2_out = act @ expert_weight_2  # [1, hidden_size]
             if b2 is not None:
-                gmm_2_out += b2[expert_id:expert_id + 1]
+                gmm_2_out += b2[expert_id:expert_id + 1, 0]
             tok_expert_act.append(gmm_2_out)

         # Combine outputs from all selected experts
         experts_act = jnp.concatenate(tok_expert_act,
-                                      axis=0)  # [top_k,
+                                      axis=0)  # [top_k, hidden_size]

         # Weighted sum using top-k gating weights
         top_k_weights = top_k_logits[i]  # [top_k]
         top_k_weights = jnp.expand_dims(top_k_weights, axis=1)  # [top_k, 1]
         weighted_output = jnp.sum(experts_act * top_k_weights,
                                   axis=0,
-                                  keepdims=True)  # [1,
+                                  keepdims=True)  # [1, hidden_size]

         t_outputs.append(weighted_output.astype(tokens.dtype))

-    return jnp.concatenate(t_outputs,
+    return jnp.concatenate(t_outputs,
+                           axis=0)  # [actual_num_tokens, hidden_size]


 def _fused_ep_moe_kernel(
@@ -177,7 +193,7 @@ def _fused_ep_moe_kernel(
     # Output
     output_hbm,  # (local_num_tokens, hidden_size)
     # Scratch
-    t2e_routing_x2_smem,  # <bt_sem_id> (2, bt,
+    t2e_routing_x2_smem,  # <bt_sem_id> (2, bt, padded_top_k)
     d2e_count_x2_smem,  # <bt_sem_id> (2, num_devices, 1, padded_num_experts)
     expert_offsets_x2_smem,  # <bt_sem_id> (2, 2, padded_num_experts): for a2a_s and a2a_g
     expert_starts_x2_smem,  # <bt_sem_id> (2, 1, padded_num_experts)
@@ -227,6 +243,11 @@ def _fused_ep_moe_kernel(
     local_num_tokens = tokens_hbm.shape[0]
     local_num_experts, intermediate_size, hidden_size = w2_hbm.shape
     right_id = (my_id + 1) % num_devices
+    num_experts = a2a_g_hbm.shape[0]
+    padded_num_experts = d2e_count_x2_smem.shape[-1]
+    padded_top_k = t2e_routing_x2_smem.shape[-1]
+    assert padded_num_experts == align_to(num_experts, 128)
+    assert padded_top_k == align_to(top_k, 128)

     t_dtype = tokens_hbm.dtype
     t_packing = get_dtype_packing(t_dtype)
@@ -300,35 +321,40 @@ def _fused_ep_moe_kernel(
     def get_top_k(input, top_k, renormalize_topk_logits):
         assert len(input.shape) == 2, input.shape
         input = input.astype(jnp.float32)
+        padded_k_shape = (input.shape[0], padded_top_k)
         top_k_logits_lst = []
         top_k_indices_lst = []
         t2e = jnp.zeros(input.shape, dtype=jnp.int32)
-        t2e_routing = jnp.zeros(
+        t2e_routing = jnp.zeros(padded_k_shape, dtype=jnp.int32)
         iota = jax.lax.broadcasted_iota(jnp.int32, input.shape, 1)
-
+        padded_k_iota = jax.lax.broadcasted_iota(jnp.int32, padded_k_shape, 1)
+        top_k_logits_sum = jnp.zeros(padded_k_shape, jnp.float32)

         for k_id in range(top_k):
             # TODO(jevinjiang): return both top_k values and indices in Mosaic
             top_k_logits = jnp.broadcast_to(
-                jnp.max(input, axis=1, keepdims=True),
-
+                jnp.max(input[:, :num_experts], axis=1, keepdims=True),
+                padded_k_shape,
+            ).astype(input.dtype)
+            top_k_logits_lst.append(top_k_logits)
             if renormalize_topk_logits:
                 top_k_logits_sum += top_k_logits
-            top_k_logits_lst.append(top_k_logits)
             # TODO(jevinjiang): support bf16 argmax in Mosaic
             top_k_indices = jnp.broadcast_to(
-                jnp.argmax(input, axis=1, keepdims=True),
+                jnp.argmax(input[:, :num_experts], axis=1, keepdims=True),
+                padded_k_shape,
+            )
             top_k_indices_lst.append(top_k_indices)
-            t2e_routing = jnp.where(
-
+            t2e_routing = jnp.where(padded_k_iota == k_id, top_k_indices,
+                                    t2e_routing)
+            mask = iota == broadcast_minor(top_k_indices, input.shape)
             t2e += mask.astype(jnp.int32)
             if k_id != top_k - 1:
                 input = jnp.where(mask, -jnp.inf, input)

         if renormalize_topk_logits:
             for k_id in range(top_k):
-                top_k_logits_lst[
-                    k_id] = top_k_logits_lst[k_id] / top_k_logits_sum
+                top_k_logits_lst[k_id] /= top_k_logits_sum

         expert_sizes = jnp.sum(t2e, axis=0, keepdims=True)
         expert_starts = jnp.zeros_like(expert_sizes)
@@ -1071,27 +1097,38 @@ def _fused_ep_moe_kernel(

         all_reduce_metadata(bt_sem_id, t2e_routing, expert_starts,
                             expert_sizes)
+        sync_barrier()

+        # Start a2a scatter for first active expert.
         start_a2a_scatter(bt_id=bt_id, e_sem_id=e_sem_id, local_e_id=0)

     def run_per_expert(local_e_id, e_sem_id):
         sync_barrier()
+
+        # Prefetch weights for CURRENT active expert.
+        # TODO(jevinjiang): It is hard to prefetch weights in previous iteration
+        # because the expert_ffn keeps overwriting the buffers. Triple buffering
+        # could resolve this but it takes more VMEM scratch. Need further
+        # experiment on this.
+        start_fetch_bw1(local_e_id, bw1_sem_id=0, bf_id=0, bd1_id=0)
+        start_fetch_bw3(local_e_id, bw3_sem_id=0, bf_id=0, bd3_id=0)
+
+        # Next ids.
         next_e_sem_id = lax.select(e_sem_id == 0, 1, 0)
         next_local_e_id = local_e_id + 1

+        # Start a2a scatter for NEXT active expert.
         @pl.when(next_local_e_id < local_num_experts)
         def _():
             start_a2a_scatter(bt_id, next_e_sem_id, next_local_e_id)

-        #
-        start_fetch_bw1(local_e_id, bw1_sem_id=0, bf_id=0, bd1_id=0)
-        start_fetch_bw3(local_e_id, bw3_sem_id=0, bf_id=0, bd3_id=0)
-
-        # Wait for a2a scatter and perform FFN for active expert.
+        # Wait a2a scatter for CURRENT active expert.
         wait_a2a_scatter_recv(bt_id, e_sem_id, local_e_id)
+
+        # Perform FFN for CURRENT active expert.
         expert_ffn(bt_id, e_sem_id, local_e_id)

-        #
+        # Start a2a gather to send back tokens for CURRENT active expert.
         start_a2a_gather(bt_id, e_sem_id, local_e_id)

         # A must-wait before next sync_barrier.
@@ -1104,7 +1141,10 @@ def _fused_ep_moe_kernel(
                             e_sem_id,
                             unroll=False)

+    # Wait to receive a2a gather for ALL experts.
     wait_a2a_gather_recv_all()
+
+    # Accumulate results for current batch.
     output = bt_acc(bt_id, top_k_logits_lst)

     # Make sure it is safe to overwrite output buffer.
@@ -1158,18 +1198,18 @@ def fused_ep_moe(
     w2: jax.Array,  # (num_experts, intermediate_size, hidden_size)
     gating_output: jax.Array,  # (num_tokens, num_experts)
     top_k: int,
+    *,
     renormalize_topk_logits: bool = False,
     act_fn: str = "silu",
-    *,
     subc_quant_wsz: int | None = None,
     w1_scale: (
         jax.Array | None
-    ) = None,  # (num_experts, 2,
+    ) = None,  # F32(num_experts, 2, hidden_size // subc_quant_wsz, 1, intermediate_size)
     w2_scale: (
         jax.Array | None
-    ) = None,  # (num_experts,
-    b1: jax.Array | None = None,  # (num_experts, 2, intermediate_size)
-    b2: jax.Array | None = None,  # (num_experts, hidden_size)
+    ) = None,  # F32(num_experts, intermediate_size // subc_quant_wsz, 1, hidden_size)
+    b1: jax.Array | None = None,  # F32(num_experts, 2, 1, intermediate_size)
+    b2: jax.Array | None = None,  # F32(num_experts, 1, hidden_size)
     # Kernel tuning parameters.
     bt: int,
     bf: int,
@@ -1182,75 +1222,159 @@ def fused_ep_moe(
     ep_axis_name: str = "model",
 ):
     # TODO(jevinjiang): move all these assertions to validation function.
-
-
-
-
+    if len(mesh.shape) != 2:
+        raise NotImplementedError("Only 2D mesh is supported.")
+
+    for axis_name in mesh.axis_names:
+        if axis_name == ep_axis_name:
+            continue
+        if mesh.shape[axis_name] != 1:
+            raise NotImplementedError(
+                f"Expected all non-ep axis to have size 1 in {mesh.shape=}")

     ep_size = mesh.shape[ep_axis_name]
     num_devices = ep_size

-    num_tokens,
-    num_experts,
+    num_tokens, hidden_size = tokens.shape
+    num_experts, intermediate_size, _ = w2.shape
+
+    if w1.shape != (num_experts, 2, hidden_size, intermediate_size):
+        raise ValueError(
+            f"Expected {w1.shape=} to be"
+            f" {(num_experts, 2, hidden_size, intermediate_size)}.")
+
+    if w2.shape != (num_experts, intermediate_size, hidden_size):
+        raise ValueError(f"Expected {w2.shape=} to be"
+                         f" {(num_experts, intermediate_size, hidden_size)}.")

-
-
+    if gating_output.shape != (num_tokens, num_experts):
+        raise ValueError(
+            f"Expected {gating_output.shape=} to be {(num_tokens, num_experts)}."
+        )
+
+    if not (0 < top_k <= num_experts):
+        raise ValueError(
+            f"Expected {top_k=} to be in range (0, {num_experts=}].")
+
+    if hidden_size % 128 != 0 or intermediate_size % 128 != 0:
+        raise ValueError(
+            f"Expected {hidden_size=} and {intermediate_size=} to be aligned to"
+            " 128. Did you pad them with zeros outside the kernel?")
+    if num_tokens % ep_size != 0:
+        raise ValueError(
+            f"Expected {num_tokens=} to be aligned to {ep_size=}.")
+    if num_experts % ep_size != 0:
+        raise ValueError(
+            f"Expected {num_experts=} to be aligned to {ep_size=}.")

     local_num_tokens = num_tokens // ep_size
     # local_num_experts = num_experts // ep_size
     padded_num_experts = align_to(num_experts, 128)
+    padded_top_k = align_to(top_k, 128)
     t_dtype = tokens.dtype
     t_packing = get_dtype_packing(t_dtype)

+    # Override bt
+    if local_num_tokens <= t_packing * 8:
+        bt = local_num_tokens
+        btc = bt
+    bt = min(local_num_tokens, bt)
+    # The worst case is that all devices send bt to one device.
+    btc = min(bt, btc, bt * num_devices)
+
+    if local_num_tokens % t_packing != 0:
+        raise ValueError(
+            f"Expected {local_num_tokens=} to be aligned to {t_packing=}.")
+
+    if bt % t_packing != 0:
+        raise ValueError(f"Expected {bt=} to be aligned to {t_packing=}.")
+    if local_num_tokens % bt != 0:
+        raise ValueError(
+            f"Expected {local_num_tokens=} to be aligned to {bt=}.")
+
     if subc_quant_wsz is not None:
+        if subc_quant_wsz <= 0:
+            raise ValueError(f"Expected {subc_quant_wsz=} to be non-negative.")
         if subc_quant_wsz % 256 != 0:
-            raise
-                "
-
+            raise ValueError(
+                "Expected {subc_quant_wsz=} to be aligned to 256.")
+        if hidden_size % subc_quant_wsz != 0:
+            raise ValueError(
+                f"Expected {hidden_size=} to be aligned to {subc_quant_wsz=}.")
+        if intermediate_size % subc_quant_wsz != 0:
+            raise ValueError(
+                f"Expected {intermediate_size=} to be aligned to {subc_quant_wsz=}."
+            )
+        # We force compute size of contracting dim to be subc_quant_wsz. So we can
         # apply same scale after matmul and accumulation.
         bd1c = subc_quant_wsz * t_packing
         bfc = subc_quant_wsz

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if bfc % 128 != 0:
+        raise ValueError(f"Expected {bfc=} to be aligned to 128.")
+    if bd1c % (t_packing * 128) != 0:
+        raise ValueError(
+            f"Expected {bd1c=} to be aligned to {t_packing * 128}.")
+    if bd2c % (t_packing * 128) != 0:
+        raise ValueError(
+            f"Expected {bd2c=} to be aligned to {t_packing * 128}.")
+    if bf % bfc != 0:
+        raise ValueError(f"Expected {bf=} to be aligned to {bfc=}.")
+    if bd1 % bd1c != 0:
+        raise ValueError(f"Expected {bd1=} to be aligned to {bd1c=}.")
+    if bd2 % bd2c != 0:
+        raise ValueError(f"Expected {bd2=} to be aligned to {bd2c=}.")
+    if hidden_size % bd1 != 0 or hidden_size % bd2 != 0:
+        raise ValueError(
+            f"Expected {hidden_size=} to be aligned to {bd1=} and {bd2=}.")
+    if intermediate_size % bf != 0:
+        raise ValueError(
+            f"Expected {intermediate_size=} to be aligned to {bf=}.")
+
+    # Note: we should dump scale as the kernel expected shape in the
     # checkpoint offline or reshape right after weight loading.
     if w1_scale is not None:
-
-
-
-
-
+        expected_w1_scale_shape = (
+            num_experts,
+            2,
+            hidden_size // subc_quant_wsz,
+            1,
+            intermediate_size,
+        )
+        if w1_scale.shape != expected_w1_scale_shape:
+            raise ValueError(
+                f"Expected {w1_scale.shape=} to be {expected_w1_scale_shape}.")
+        if w1_scale.dtype != jnp.float32:
+            w1_scale = w1_scale.astype(jnp.float32)

     if w2_scale is not None:
-
-
-
-
+        expected_w2_scale_shape = (
+            num_experts,
+            intermediate_size // subc_quant_wsz,
+            1,
+            hidden_size,
+        )
+        if w2_scale.shape != expected_w2_scale_shape:
+            raise ValueError(
+                f"Expected {w2_scale.shape=} to be {expected_w2_scale_shape}.")
+        if w2_scale.dtype != jnp.float32:
+            w2_scale = w2_scale.astype(jnp.float32)

     if b1 is not None:
-
-
-
-
+        expected_b1_shape = (num_experts, 2, 1, intermediate_size)
+        if b1.shape != expected_b1_shape:
+            raise ValueError(
+                f"Expected {b1.shape=} to be {expected_b1_shape}.")
+        if b1.dtype != jnp.float32:
+            b1 = b1.astype(jnp.float32)

     if b2 is not None:
-
-
-
+        expected_b2_shape = (num_experts, 1, hidden_size)
+        if b2.shape != expected_b2_shape:
+            raise ValueError(
+                f"Expected {b2.shape=} to be {expected_b2_shape}.")
+        if b2.dtype != jnp.float32:
+            b2 = b2.astype(jnp.float32)

     # Prepare inputs for the kernel.
     if padded_num_experts != gating_output.shape[-1]:
@@ -1260,248 +1384,171 @@ def fused_ep_moe(
             constant_values=-jnp.inf,
         )

-    if (hidden_size != actual_hidden_size
-            or intermediate_size != actual_intermediate_size):
-        tokens = jnp.pad(
-            tokens,
-            ((0, 0), (0, hidden_size - actual_hidden_size)),
-            constant_values=0,
-        )
-        w1 = jnp.pad(
-            w1,
-            (
-                (0, 0),
-                (0, 0),
-                (0, hidden_size - actual_hidden_size),
-                (0, intermediate_size - actual_intermediate_size),
-            ),
-            constant_values=0,
-        )
-        w2 = jnp.pad(
-            w2,
-            (
-                (0, 0),
-                (0, intermediate_size - actual_intermediate_size),
-                (0, hidden_size - actual_hidden_size),
-            ),
-            constant_values=0,
-        )
-        if w1_scale is not None:
-            w1_scale = jnp.pad(
-                w1_scale,
-                (
-                    (0, 0),
-                    (0, 0),
-                    (0,
-                     cdiv(hidden_size, subc_quant_wsz) - w1_scale.shape[-3]),
-                    (0, 0),
-                    (0, intermediate_size - w1_scale.shape[-1]),
-                ),
-                constant_values=0,
-            )
-        if w2_scale is not None:
-            w2_scale = jnp.pad(
-                w2_scale,
-                (
-                    (0, 0),
-                    (0, cdiv(intermediate_size, subc_quant_wsz) -
-                     w2_scale.shape[-3]),
-                    (0, 0),
-                    (0, hidden_size - w2_scale.shape[-1]),
-                ),
-                constant_values=0,
-            )
-        if b1 is not None:
-            b1 = jnp.pad(
-                b1,
-                (
-                    (0, 0),
-                    (0, 0),
-                    (0, 0),
-                    (0, intermediate_size - b1.shape[-1]),
-                ),
-                constant_values=0,
-            )
-        if b2 is not None:
-            b2 = jnp.pad(
-                b2,
-                (
-                    (0, 0),
-                    (0, 0),
-                    (0, hidden_size - b2.shape[-1]),
-                ),
-                constant_values=0,
-            )
-
     tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing)

     hbm_block_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        bt * num_devices,
-        t_packing,
-        hidden_size // t_packing,
-    ),
-    t_dtype,
+    renorm_str = "-renorm_k" if renormalize_topk_logits else ""
+    scope_name = f"fused-moe-k_{top_k}{renorm_str}-bt_{bt}_{btc}-bf_{bf}_{bfc}-bd1_{bd1}_{bd1c}-bd2_{bd2}_{bd2c}"
+    fused_moe = pl.pallas_call(
+        functools.partial(
+            _fused_ep_moe_kernel,
+            top_k=top_k,
+            renormalize_topk_logits=renormalize_topk_logits,
+            ep_axis_name=ep_axis_name,
+            act_fn=act_fn,
+            subc_quant_wsz=subc_quant_wsz,
+            bt=bt,
+            bf=bf,
+            bd1=bd1,
+            bd2=bd2,
+            btc=btc,
+            bfc=bfc,
+            bd1c=bd1c,
+            bd2c=bd2c,
+        ),
+        out_shape=jax.ShapeDtypeStruct((local_num_tokens, hidden_size),
+                                       t_dtype),
+        grid_spec=pltpu.PrefetchScalarGridSpec(
+            num_scalar_prefetch=0,
+            in_specs=[
+                hbm_block_spec,  # tokens_hbm
+                hbm_block_spec,  # w1_hbm
+                hbm_block_spec,  # w2_hbm
+                None if w1_scale is None else hbm_block_spec,  # w1_scale_hbm
+                None if w2_scale is None else hbm_block_spec,  # w2_scale_hbm
+                None if b1 is None else hbm_block_spec,  # b1_hbm
+                None if b2 is None else hbm_block_spec,  # b2_hbm
+                hbm_block_spec,  # gating_output_hbm
+                hbm_block_spec,  # a2a_g_hbm
+            ],
+            out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
+            scratch_shapes=([
+                # t2e_routing_x2_smem
+                pltpu.SMEM((2, bt, padded_top_k), jnp.int32),
+                # d2e_count_x2_smem
+                pltpu.SMEM((2, num_devices, 1, padded_num_experts), jnp.int32),
+                # expert_offsets_x2_smem
+                pltpu.SMEM((2, 2, padded_num_experts), jnp.int32),
+                # expert_starts_x2_smem
+                pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
+                # expert_sizes_x2_smem
+                pltpu.SMEM((2, 1, padded_num_experts), jnp.int32),
+                # a2a_s_sends_x2_smem
+                pltpu.SMEM((2, ), jnp.int32),
+                # a2a_s_x2_vmem
+                pltpu.VMEM(
+                    (
+                        2,
+                        bt * num_devices,
+                        t_packing,
+                        hidden_size // t_packing,
                     ),
-
-
-
-
-
-
-
-
-
+                    t_dtype,
+                ),
+                # a2a_s_acc_x2_vmem
+                pltpu.VMEM(
+                    (
+                        2,
+                        bt * num_devices,
+                        t_packing,
+                        hidden_size // t_packing,
                     ),
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        ),
-
-
-
-
-
-
+                    t_dtype,
+                ),
+                # a2a_g_acc_vmem
+                pltpu.VMEM((top_k, bt, t_packing, hidden_size // t_packing),
+                           t_dtype),
+                # b_gating_x2_vmem
+                pltpu.VMEM((2, bt, padded_num_experts), t_dtype),
+                # b_output_x2_vmem
+                pltpu.VMEM((2, bt, hidden_size), t_dtype),
+                # b_w1_x2_vmem
+                pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
+                # b_w3_x2_vmem
+                pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
+                # b_w2_x2_vmem
+                pltpu.VMEM((2, t_packing, bf, bd2 // t_packing), w2.dtype),
+                # b_w1_scale_x2_vmem
+                (None if w1_scale is None else pltpu.VMEM(
+                    (
+                        2,
+                        t_packing,
+                        bd1 // t_packing // subc_quant_wsz,
+                        1,
+                        bf,
+                    ),
+                    jnp.float32,
+                )),
+                # b_w3_scale_x2_vmem
+                (None if w1_scale is None else pltpu.VMEM(
+                    (
+                        2,
+                        t_packing,
+                        bd1 // t_packing // subc_quant_wsz,
+                        1,
+                        bf,
+                    ),
+                    jnp.float32,
+                )),
+                # b_w2_scale_x2_vmem
+                (None if w2_scale is None else pltpu.VMEM(
+                    (
+                        2,
+                        t_packing,
+                        bf // subc_quant_wsz,
+                        1,
+                        bd2 // t_packing,
+                    ),
+                    jnp.float32,
+                )),
+                # b_b1_x2_vmem
+                (None if b1 is None else pltpu.VMEM(
+                    (
+                        2,
+                        1,
+                        bf,
+                    ),
+                    jnp.float32,
+                )),
+                # b_b3_x2_vmem
+                (None if b1 is None else pltpu.VMEM(
+                    (
+                        2,
+                        1,
+                        bf,
+                    ),
+                    jnp.float32,
+                )),
+                # b_b2_x2_vmem
+                (None if b2 is None else pltpu.VMEM(
+                    (
+                        2,
+                        t_packing,
+                        1,
+                        bd2 // t_packing,
+                    ),
+                    jnp.float32,
+                )),
+                # b_acc_vmem
+                pltpu.VMEM((bt * num_devices, 1, bf * 2), jnp.float32),
+                # local_sems
+                pltpu.SemaphoreType.DMA((2, 5)),
+                # send_sems
+                pltpu.SemaphoreType.DMA((2, )),
+                # recv_sems
+                pltpu.SemaphoreType.DMA((2, )),
+                # a2a_gather_sem
+                pltpu.SemaphoreType.DMA,
+                # a2a_acc_sem
+                pltpu.SemaphoreType.DMA,
+            ]),
+        ),
+        compiler_params=pltpu.CompilerParams(
+            collective_id=0,
+            vmem_limit_bytes=100 * 1024 * 1024,
+        ),
+        name=scope_name,
+    )

     @jax.jit
     @jax.shard_map(
@@ -1552,7 +1599,7 @@ def fused_ep_moe(

     a2a_g_hbm_scratch = pl.empty(
         (num_experts, bt, t_packing, hidden_size // t_packing), t_dtype)
-
+    return kernel(
         tokens,
         w1,
         w2,
@@ -1563,4 +1610,3 @@ def fused_ep_moe(
         gating_output,
         a2a_g_hbm_scratch,
     )
-    return results[:, :actual_hidden_size]