tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/kernels/mla_v1_test.py +129 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
- tests/lora/test_layers.py +4 -1
- tests/lora/test_lora_perf.py +53 -0
- tests/test_envs.py +110 -12
- tests/test_quantization.py +3 -0
- tests/test_utils.py +1 -2
- tpu_inference/distributed/tpu_connector.py +1 -1
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/ray_distributed_executor.py +5 -1
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
- tpu_inference/kernels/mla/v1/kernel.py +98 -120
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +82 -32
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +146 -85
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +11 -7
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +170 -208
- tpu_inference/layers/vllm/linear_common.py +43 -21
- tpu_inference/layers/vllm/quantization/common.py +11 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
- tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
- tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
- tpu_inference/models/common/model_loader.py +78 -22
- tpu_inference/models/jax/deepseek_v3.py +185 -64
- tpu_inference/models/jax/gpt_oss.py +3 -3
- tpu_inference/models/jax/llama_eagle3.py +4 -5
- tpu_inference/models/jax/qwen2_5_vl.py +161 -47
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
- tpu_inference/models/jax/utils/weight_utils.py +203 -155
- tpu_inference/models/vllm/vllm_model_wrapper.py +11 -5
- tpu_inference/platforms/tpu_platform.py +29 -48
- tpu_inference/runner/compilation_manager.py +112 -46
- tpu_inference/runner/kv_cache.py +40 -20
- tpu_inference/runner/kv_cache_manager.py +40 -31
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +94 -51
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -22
- tpu_inference/utils.py +41 -14
- tpu_inference/worker/tpu_worker.py +43 -45
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +8 -9
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +59 -58
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tpu_inference/kernels/fused_moe/v1/kernel.py

```diff
@@ -7,7 +7,6 @@ import jax.numpy as jnp
 from jax import lax
 from jax._src import dtypes
 from jax.experimental import pallas as pl
-from jax.experimental import shard_map
 from jax.experimental.pallas import tpu as pltpu

 P = jax.sharding.PartitionSpec
@@ -20,7 +19,8 @@ def align_to(x, a):


 def get_dtype_packing(dtype):
-    bits = dtypes.bit_width(dtype)
+    bits = (dtypes.bit_width(dtype)
+            if hasattr(dtypes, "bit_width") else dtypes.itemsize_bits(dtype))
     return 32 // bits


```
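A note on `get_dtype_packing`: it computes how many elements of a dtype fit in one 32-bit lane, which drives the `t_packing` token layout used throughout this kernel. A minimal sketch of the same arithmetic using only the public dtype itemsize (the `bit_width`/`itemsize_bits` helpers above live in the private `jax._src.dtypes` namespace, hence the `hasattr` fallback in the diff):

```python
import jax.numpy as jnp

def packing_of(dtype):
    # How many elements fit in one 32-bit lane, mirroring get_dtype_packing().
    bits = jnp.dtype(dtype).itemsize * 8
    return 32 // bits

print(packing_of(jnp.float32))   # 1
print(packing_of(jnp.bfloat16))  # 2 -> bf16 tokens are packed two per lane
print(packing_of(jnp.int8))      # 4
```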
```diff
@@ -35,13 +35,50 @@ def broadcast_minor(src, shape):
                       axis=-1)[..., :shape[-1]]


+def swigluoai(gate: jax.Array,
+              up: jax.Array,
+              *,
+              alpha: float = 1.702,
+              limit: float = 7.0) -> jax.Array:
+    """Activation used in some models such as GPT-OSS."""
+    gate = jnp.clip(gate, a_max=limit)
+    up = jnp.clip(up, a_min=-limit, a_max=limit)
+    glu = gate * jax.nn.sigmoid(alpha * gate)
+    return (up + 1.0) * glu
+
+
+def activation_fn(acc1, acc3, act_fn):
+    if act_fn == "silu":
+        return jax.nn.silu(acc1) * acc3
+    elif act_fn == "gelu":
+        return jax.nn.gelu(acc1) * acc3
+    elif act_fn == "swigluoai":
+        return swigluoai(acc1, acc3)
+    else:
+        raise RuntimeError(f"Unsupported activation function: {act_fn}")
+
+
 def ref_moe(
-
-
-
-
-
-
+    tokens: jax.Array,  # (num_tokens, hidden_size)
+    w1: jax.Array,  # (num_experts, 2, hidden_size, intermediate_size)
+    w2: jax.Array,  # (num_experts, intermediate_size, hidden_size)
+    gating_output: jax.Array,  # (num_tokens, num_experts)
+    top_k: int,
+    *,
+    renormalize_topk_logits: bool = False,
+    act_fn: str = "silu",
+    subc_quant_wsz: int | None = None,
+    w1_scale:
+    (
+        jax.Array | None
+    ) = None,  # F32(num_experts, 2, hidden_size //subc_quant_wsz, 1, intermediate_size)
+    w2_scale:
+    (
+        jax.Array | None
+    ) = None,  # F32(num_experts, intermediate_size // subc_quant_wsz, 1, hidden_size)
+    b1: jax.Array
+    | None = None,  # F32(num_experts, 2, 1, intermediate_size)
+    b2: jax.Array | None = None,  # F32(num_experts, 1, hidden_size)
 ):
     n_tokens = tokens.shape[0]  # num_tokens

```
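The new `swigluoai` activation clips the gate from above and the up-projection on both sides before the sigmoid-gated product. A standalone sketch of the same math, written with `jnp.minimum`/positional `jnp.clip` bounds so it is robust to the `a_min`/`a_max` keyword removal in newer JAX releases:

```python
import jax
import jax.numpy as jnp

def swigluoai_sketch(gate, up, alpha=1.702, limit=7.0):
    # Same math as swigluoai() in the diff, with explicit clip bounds.
    gate = jnp.minimum(gate, limit)     # gate is clipped from above only
    up = jnp.clip(up, -limit, limit)    # up is clipped on both sides
    return (up + 1.0) * (gate * jax.nn.sigmoid(alpha * gate))

x = jnp.linspace(-10.0, 10.0, 5)
print(swigluoai_sketch(x, x))  # saturates toward (limit + 1) * limit = 56
```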
```diff
@@ -53,11 +90,16 @@ def ref_moe(
     top_k_logits, top_k_indices = lax.top_k(
         gating_logits, top_k)  # [num_tokens, top_k], [num_tokens, top_k]

+    if renormalize_topk_logits:
+        top_k_logits = top_k_logits / jnp.sum(
+            top_k_logits, axis=-1, keepdims=True)
+
     t_outputs = []
+    hidden_size, intermediate_size = w1.shape[-2:]

     # Process each token individually
     for i in range(n_tokens):
-        curr_token = jnp.expand_dims(tokens[i], axis=0)  # [1,
+        curr_token = jnp.expand_dims(tokens[i], axis=0)  # [1, hidden_size]
         assigned_expert_ids = top_k_indices[
             i]  # [top_k] - indices of selected experts for token i
         tok_expert_act = []
@@ -65,10 +107,24 @@
         # Process each selected expert for the current token
         for expert_id in assigned_expert_ids:
             # Get expert weights
+            expert_w1 = w1[expert_id, 0].astype(jnp.float32)
+            expert_w3 = w1[expert_id, 1].astype(jnp.float32)
+            if w1_scale is not None:
+                expert_w1 *= jnp.repeat(w1_scale[expert_id, 0, :, 0],
+                                        subc_quant_wsz,
+                                        axis=0)[:hidden_size]
+                expert_w3 *= jnp.repeat(w1_scale[expert_id, 1, :, 0],
+                                        subc_quant_wsz,
+                                        axis=0)[:hidden_size]
             expert_weight_1 = jnp.concat(
-                [
-                axis=-1)  # [
-            expert_weight_2 = w2[expert_id]
+                [expert_w1, expert_w3],
+                axis=-1)  # [hidden_size, 2 * intermediate_size]
+            expert_weight_2 = w2[expert_id].astype(
+                jnp.float32)  # [intermediate_size, hidden_size]
+            if w2_scale is not None:
+                expert_weight_2 *= jnp.repeat(w2_scale[expert_id, :, 0],
+                                              subc_quant_wsz,
+                                              axis=0)[:intermediate_size]

             # First linear layer with SwiGLU activation
             gmm_1_out = curr_token @ expert_weight_1  # [1, 2 * intermediate_size]
@@ -77,37 +133,34 @@
             gmm1_w1_proj, gmm1_w3_proj = jnp.split(
                 gmm_1_out, 2,
                 axis=-1)  # [1, intermediate_size], [1, intermediate_size]
+            if b1 is not None:
+                gmm1_w1_proj += b1[expert_id:expert_id + 1, 0, 0]
+                gmm1_w3_proj += b1[expert_id:expert_id + 1, 1, 0]

             # Apply gated activation: activation(gate) * up
-
-                act = jax.nn.silu(
-                    gmm1_w1_proj) * gmm1_w3_proj  # [1, intermediate_size]
-            elif activation == "gelu":
-                act = jax.nn.gelu(
-                    gmm1_w1_proj) * gmm1_w3_proj  # [1, intermediate_size]
-            else:
-                raise ValueError(
-                    f"Unsupported activation: {activation}. Use 'silu' or 'gelu'."
-                )
+            act = activation_fn(gmm1_w1_proj, gmm1_w3_proj, act_fn)

             # Second linear layer (down projection)
-            gmm_2_out = act @ expert_weight_2  # [1,
+            gmm_2_out = act @ expert_weight_2  # [1, hidden_size]
+            if b2 is not None:
+                gmm_2_out += b2[expert_id:expert_id + 1, 0]
             tok_expert_act.append(gmm_2_out)

         # Combine outputs from all selected experts
         experts_act = jnp.concatenate(tok_expert_act,
-                                      axis=0)  # [top_k,
+                                      axis=0)  # [top_k, hidden_size]

         # Weighted sum using top-k gating weights
         top_k_weights = top_k_logits[i]  # [top_k]
         top_k_weights = jnp.expand_dims(top_k_weights, axis=1)  # [top_k, 1]
         weighted_output = jnp.sum(experts_act * top_k_weights,
                                   axis=0,
-                                  keepdims=True)  # [1,
+                                  keepdims=True)  # [1, hidden_size]

-        t_outputs.append(weighted_output)
+        t_outputs.append(weighted_output.astype(tokens.dtype))

-    return jnp.concatenate(t_outputs,
+    return jnp.concatenate(t_outputs,
+                           axis=0)  # [actual_num_tokens, hidden_size]


 def _fused_ep_moe_kernel(
```
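In the reference path above, dequantization broadcasts one f32 scale per `subc_quant_wsz`-sized group of the contracting dimension via `jnp.repeat`. A toy illustration of that broadcast (group size shrunk for readability; the kernel itself requires `subc_quant_wsz % 256 == 0`):

```python
import jax.numpy as jnp

wsz = 4                                   # toy group size
hidden_size, intermediate_size = 8, 2
w = jnp.ones((hidden_size, intermediate_size))  # stand-in quantized values
scale = jnp.array([[0.5], [2.0]])               # one scale per group
w_dequant = w * jnp.repeat(scale, wsz, axis=0)[:hidden_size]
print(w_dequant[:, 0])  # [0.5 0.5 0.5 0.5 2.  2.  2.  2. ]
```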
```diff
@@ -115,12 +168,19 @@ def _fused_ep_moe_kernel(
     tokens_hbm,  # (local_num_tokens, t_packing, hidden_size // t_packing)
     w1_hbm,  # (local_num_experts, 2, hidden_size, intermediate_size)
     w2_hbm,  # (local_num_experts, intermediate_size, hidden_size)
+    # TODO(jevinjiang): We choose F32 scale for easier slicing. The extra
+    # latency should be hidden in the pipeline overlaping. But is there a better
+    # way to do this?
+    w1_scale_hbm,  # None | F32(local_num_experts, 2, cdiv(hidden_size, subc_quant_wsz), 1, intermediate_size)
+    w2_scale_hbm,  # None | F32(local_num_experts, cdiv(intermediate_size, subc_quant_wsz), 1, hidden_size)
+    b1_hbm,  # None | F32(local_num_experts, 2, 1, intermediate_size)
+    b2_hbm,  # None | F32(local_num_experts, 1, hidden_size)
     gating_hbm,  # (local_num_tokens, padded_num_experts)
     a2a_g_hbm,  # (num_experts, bt, t_packing, hidden_size // t_packing)
     # Output
     output_hbm,  # (local_num_tokens, hidden_size)
     # Scratch
-    t2e_routing_x2_smem,  # <bt_sem_id> (2, bt,
+    t2e_routing_x2_smem,  # <bt_sem_id> (2, bt, padded_top_k)
     d2e_count_x2_smem,  # <bt_sem_id> (2, num_devices, 1, padded_num_experts)
     expert_offsets_x2_smem,  # <bt_sem_id> (2, 2, padded_num_experts): for a2a_s and a2a_g
     expert_starts_x2_smem,  # <bt_sem_id> (2, 1, padded_num_experts)
@@ -136,6 +196,12 @@ def _fused_ep_moe_kernel(
     b_w1_x2_vmem,  # <bw_sem_id> (2, t_packing, bd1 // t_packing, bf)
     b_w3_x2_vmem,  # <bw_sem_id> (2, t_packing, bd1 // t_packing, bf)
     b_w2_x2_vmem,  # <bw_sem_id> (2, t_packing, bf, bd2 // t_packing)
+    b_w1_scale_x2_vmem,  # None | <bw_sem_id> (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf)
+    b_w3_scale_x2_vmem,  # None | <bw_sem_id> (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf)
+    b_w2_scale_x2_vmem,  # None | <bw_sem_id> (2, t_packing, bf // subc_quant_wsz, 1, bd2 // t_packing)
+    b_b1_x2_vmem,  # None | <bw_sem_id> (2, 1, bf)
+    b_b3_x2_vmem,  # None | <bw_sem_id> (2, 1, bf)
+    b_b2_x2_vmem,  # None | <bw_sem_id> (2, t_packing, 1, bd2 // t_packing)
     b_acc_vmem,  # F32(bt * num_devices, 1, bf * 2)
     ### Semaphores:
     local_sems,  # (2, 5): 2 x [b_gating_sem, b_w1_sem, b_w2_sem, b_w3_sem, b_output_sem]
@@ -145,7 +211,10 @@
     a2a_acc_sem,
     *,
     top_k: int,
+    renormalize_topk_logits: bool,
     ep_axis_name: str,
+    act_fn: str,
+    subc_quant_wsz: int | None = None,
     # Kernel tuning params.
     bt: int,  # Block size of local_num_tokens.
     bf: int,  # Block size of intermediate_size.
@@ -160,34 +229,58 @@
     num_devices = lax.axis_size(ep_axis_name)
     local_num_tokens = tokens_hbm.shape[0]
     local_num_experts, intermediate_size, hidden_size = w2_hbm.shape
-    # num_experts = local_num_experts * num_devices
-    # padded_num_experts = expert_starts_x2_smem.shape[-1]
     right_id = (my_id + 1) % num_devices
+    num_experts = a2a_g_hbm.shape[0]
+    padded_num_experts = d2e_count_x2_smem.shape[-1]
+    padded_top_k = t2e_routing_x2_smem.shape[-1]
+    assert padded_num_experts == align_to(num_experts, 128)
+    assert padded_top_k == align_to(top_k, 128)

     t_dtype = tokens_hbm.dtype
     t_packing = get_dtype_packing(t_dtype)
     t_bitwidth = 32 // t_packing
     assert a2a_g_hbm.dtype == t_dtype
-    assert w1_hbm.dtype ==
-    assert w2_hbm.dtype == t_dtype
+    assert w1_hbm.dtype == w2_hbm.dtype

-
-    assert
-
-
-
-
+    assert bd1 % bd1c == 0
+    assert bd2 % bd2c == 0
+    assert bf % bfc == 0
+    assert hidden_size % t_packing == 0
+    assert bd1 % t_packing == 0
+    assert bd2 % t_packing == 0
+    assert bd1c % t_packing == 0
+    assert bd2c % t_packing == 0
+
+    h_per_t_packing = hidden_size // t_packing
+    assert tokens_hbm.shape[-1] == h_per_t_packing
+    bd1_per_t_packing = bd1 // t_packing
+    bd2_per_t_packing = bd2 // t_packing
+    bd1c_per_t_packing = bd1c // t_packing
+    bd2c_per_t_packing = bd2c // t_packing
+
+    if subc_quant_wsz is not None:
+        assert subc_quant_wsz % 256 == 0
+        assert bd1c_per_t_packing == subc_quant_wsz
+        assert bfc == subc_quant_wsz
+        assert bd1 % subc_quant_wsz == 0
+        assert bf % subc_quant_wsz == 0
+        assert bd1_per_t_packing % subc_quant_wsz == 0
+        assert h_per_t_packing % subc_quant_wsz == 0

     num_bt = cdiv(local_num_tokens, bt)
     num_bf = cdiv(intermediate_size, bf)
     num_bd1 = cdiv(hidden_size, bd1)
     num_bd2 = cdiv(hidden_size, bd2)

+    def get_mesh_device_id(ep_rank):
+        dp_rank = jax.lax.axis_index("data")
+        return (dp_rank, ep_rank)
+
     def sync_barrier():
         barrier_sem = pltpu.get_barrier_semaphore()
         pltpu.semaphore_signal(
             barrier_sem,
-            device_id=(
+            device_id=get_mesh_device_id(right_id),
             device_id_type=pltpu.DeviceIdType.MESH,
         )
         pltpu.semaphore_wait(barrier_sem, 1)
```
```diff
@@ -212,30 +305,44 @@
             sem=b_gating_sem,
         ).wait()

-    def get_top_k(input, top_k):
+    def get_top_k(input, top_k, renormalize_topk_logits):
         assert len(input.shape) == 2, input.shape
         input = input.astype(jnp.float32)
+        padded_k_shape = (input.shape[0], padded_top_k)
         top_k_logits_lst = []
         top_k_indices_lst = []
         t2e = jnp.zeros(input.shape, dtype=jnp.int32)
-        t2e_routing = jnp.zeros(
+        t2e_routing = jnp.zeros(padded_k_shape, dtype=jnp.int32)
         iota = jax.lax.broadcasted_iota(jnp.int32, input.shape, 1)
+        padded_k_iota = jax.lax.broadcasted_iota(jnp.int32, padded_k_shape, 1)
+        top_k_logits_sum = jnp.zeros(padded_k_shape, jnp.float32)
+
         for k_id in range(top_k):
-            # TODO(jevinjiang): return both top_k values and indices in
+            # TODO(jevinjiang): return both top_k values and indices in Mosaic
             top_k_logits = jnp.broadcast_to(
-                jnp.max(input, axis=1, keepdims=True),
-
+                jnp.max(input[:, :num_experts], axis=1, keepdims=True),
+                padded_k_shape,
+            ).astype(input.dtype)
             top_k_logits_lst.append(top_k_logits)
+            if renormalize_topk_logits:
+                top_k_logits_sum += top_k_logits
             # TODO(jevinjiang): support bf16 argmax in Mosaic
             top_k_indices = jnp.broadcast_to(
-                jnp.argmax(input, axis=1, keepdims=True),
+                jnp.argmax(input[:, :num_experts], axis=1, keepdims=True),
+                padded_k_shape,
+            )
             top_k_indices_lst.append(top_k_indices)
-            t2e_routing = jnp.where(
-
+            t2e_routing = jnp.where(padded_k_iota == k_id, top_k_indices,
+                                    t2e_routing)
+            mask = iota == broadcast_minor(top_k_indices, input.shape)
             t2e += mask.astype(jnp.int32)
             if k_id != top_k - 1:
                 input = jnp.where(mask, -jnp.inf, input)

+        if renormalize_topk_logits:
+            for k_id in range(top_k):
+                top_k_logits_lst[k_id] /= top_k_logits_sum
+
         expert_sizes = jnp.sum(t2e, axis=0, keepdims=True)
         expert_starts = jnp.zeros_like(expert_sizes)
         return top_k_logits_lst, t2e_routing, expert_sizes, expert_starts
@@ -277,7 +384,7 @@
                 dst_ref=d2e_count_vmem.at[row_id],
                 send_sem=send_sem,
                 recv_sem=recv_sem,
-                device_id=(
+                device_id=get_mesh_device_id(right_id),
                 device_id_type=pltpu.DeviceIdType.MESH,
             ).wait()
             row_id = (row_id + num_devices - 1) % num_devices
@@ -359,10 +466,8 @@
                        pl.ds(start, remote_sz)],
                 send_sem=send_sems.at[e_sem_id],
                 recv_sem=recv_sems.at[e_sem_id],
-                device_id=(
-
-                    recv_id,
-                ),
+                device_id=get_mesh_device_id(recv_id),
+                device_id_type=pltpu.DeviceIdType.MESH,
             ).start()
             a2a_s_sends_x2_smem[e_sem_id] = send_sz
@@ -406,7 +511,8 @@
                 dst_ref=a2a_g_hbm.at[my_e_id, pl.ds(0, remote_sz)],
                 send_sem=send_sems.at[e_sem_id],
                 recv_sem=a2a_gather_sem,
-                device_id=(
+                device_id=get_mesh_device_id(recv_id),
+                device_id_type=pltpu.DeviceIdType.MESH,
             ).start()
             start += sz
```
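`get_top_k` extracts the top-k one element at a time — max/argmax, then masking the winner to `-inf` — because Mosaic does not yet return values and indices together (see the TODOs above). The same loop in plain JAX, without the kernel's padding concerns:

```python
import jax
import jax.numpy as jnp

def topk_by_masking(logits, k):
    # Same idea as get_top_k: take max/argmax, mask the winner, repeat.
    vals, idxs = [], []
    x = logits.astype(jnp.float32)
    for _ in range(k):
        idx = jnp.argmax(x, axis=1, keepdims=True)
        vals.append(jnp.max(x, axis=1, keepdims=True))
        idxs.append(idx)
        mask = jax.lax.broadcasted_iota(jnp.int32, x.shape, 1) == idx
        x = jnp.where(mask, -jnp.inf, x)
    return jnp.concatenate(vals, axis=1), jnp.concatenate(idxs, axis=1)

logits = jnp.array([[0.1, 0.7, 0.2, 0.9]])
print(topk_by_masking(logits, 2))  # values [[0.9, 0.7]], indices [[3, 1]]
```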
```diff
@@ -435,68 +541,173 @@

     def start_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id):
         for p in range(t_packing):
-            offset = p *
+            offset = p * h_per_t_packing + bd1_id * bd1_per_t_packing
             pltpu.make_async_copy(
                 src_ref=w1_hbm.at[
                     local_e_id,
                     0,
-                    pl.ds(offset,
+                    pl.ds(offset, bd1_per_t_packing),
                     pl.ds(bf_id * bf, bf),
                 ],
                 dst_ref=b_w1_x2_vmem.at[bw1_sem_id, p],
                 sem=local_sems.at[bw1_sem_id, 1],
             ).start()
+            if w1_scale_hbm is not None:
+                assert subc_quant_wsz is not None
+                pltpu.make_async_copy(
+                    src_ref=w1_scale_hbm.at[
+                        local_e_id,
+                        0,
+                        pl.ds(
+                            offset // subc_quant_wsz,
+                            bd1_per_t_packing // subc_quant_wsz,
+                        ),
+                        pl.ds(0, 1),
+                        pl.ds(bf_id * bf, bf),
+                    ],
+                    dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id, p],
+                    sem=local_sems.at[bw1_sem_id, 1],
+                ).start()
+        if b1_hbm is not None and bd1_id == 0:
+            pltpu.make_async_copy(
+                src_ref=b1_hbm.at[local_e_id, 0,
+                                  pl.ds(0, 1),
+                                  pl.ds(bf_id * bf, bf)],
+                dst_ref=b_b1_x2_vmem.at[bf_id % 2],
+                sem=local_sems.at[bw1_sem_id, 1],
+            ).start()

     def start_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id):
         for p in range(t_packing):
-            offset = p *
+            offset = p * h_per_t_packing + bd2_id * bd2_per_t_packing
             pltpu.make_async_copy(
                 src_ref=w2_hbm.at[
                     local_e_id,
                     pl.ds(bf_id * bf, bf),
-                    pl.ds(offset,
+                    pl.ds(offset, bd2_per_t_packing),
                 ],
                 dst_ref=b_w2_x2_vmem.at[bw2_sem_id, p],
                 sem=local_sems.at[bw2_sem_id, 2],
             ).start()
+            if w2_scale_hbm is not None:
+                assert subc_quant_wsz is not None
+                pltpu.make_async_copy(
+                    src_ref=w2_scale_hbm.at[
+                        local_e_id,
+                        pl.ds(bf_id * bf // subc_quant_wsz, bf //
+                              subc_quant_wsz),
+                        pl.ds(0, 1),
+                        pl.ds(offset, bd2_per_t_packing),
+                    ],
+                    dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id, p],
+                    sem=local_sems.at[bw2_sem_id, 2],
+                ).start()
+            if b2_hbm is not None and bf_id == 0:
+                pltpu.make_async_copy(
+                    src_ref=b2_hbm.at[local_e_id,
+                                      pl.ds(0, 1),
+                                      pl.ds(offset, bd2_per_t_packing)],
+                    dst_ref=b_b2_x2_vmem.at[bd2_id % 2, p],
+                    sem=local_sems.at[bw2_sem_id, 2],
+                ).start()

     def start_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id):
         for p in range(t_packing):
-            offset = p *
+            offset = p * h_per_t_packing + bd3_id * bd1_per_t_packing
             pltpu.make_async_copy(
                 src_ref=w1_hbm.at[
                     local_e_id,
                     1,
-                    pl.ds(offset,
+                    pl.ds(offset, bd1_per_t_packing),
                     pl.ds(bf_id * bf, bf),
                 ],
                 dst_ref=b_w3_x2_vmem.at[bw3_sem_id, p],
                 sem=local_sems.at[bw3_sem_id, 3],
             ).start()
+            if w1_scale_hbm is not None:
+                assert subc_quant_wsz is not None
+                pltpu.make_async_copy(
+                    src_ref=w1_scale_hbm.at[
+                        local_e_id,
+                        1,
+                        pl.ds(
+                            offset // subc_quant_wsz,
+                            bd1_per_t_packing // subc_quant_wsz,
+                        ),
+                        pl.ds(0, 1),
+                        pl.ds(bf_id * bf, bf),
+                    ],
+                    dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id, p],
+                    sem=local_sems.at[bw3_sem_id, 3],
+                ).start()
+        if b1_hbm is not None and bd3_id == 0:
+            pltpu.make_async_copy(
+                src_ref=b1_hbm.at[local_e_id, 1,
+                                  pl.ds(0, 1),
+                                  pl.ds(bf_id * bf, bf)],
+                dst_ref=b_b3_x2_vmem.at[bf_id % 2],
+                sem=local_sems.at[bw3_sem_id, 3],
+            ).start()

     def wait_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id):
-        del local_e_id
+        del local_e_id
         pltpu.make_async_copy(
             src_ref=b_w1_x2_vmem.at[bw1_sem_id],
             dst_ref=b_w1_x2_vmem.at[bw1_sem_id],
             sem=local_sems.at[bw1_sem_id, 1],
         ).wait()
+        if w1_scale_hbm is not None:
+            pltpu.make_async_copy(
+                src_ref=b_w1_scale_x2_vmem.at[bw1_sem_id],
+                dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id],
+                sem=local_sems.at[bw1_sem_id, 1],
+            ).wait()
+        if b1_hbm is not None and bd1_id == 0:
+            pltpu.make_async_copy(
+                src_ref=b_b1_x2_vmem.at[bf_id % 2],
+                dst_ref=b_b1_x2_vmem.at[bf_id % 2],
+                sem=local_sems.at[bw1_sem_id, 1],
+            ).wait()

     def wait_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id):
-        del local_e_id
+        del local_e_id
         pltpu.make_async_copy(
             src_ref=b_w2_x2_vmem.at[bw2_sem_id],
             dst_ref=b_w2_x2_vmem.at[bw2_sem_id],
             sem=local_sems.at[bw2_sem_id, 2],
         ).wait()
+        if w2_scale_hbm is not None:
+            pltpu.make_async_copy(
+                src_ref=b_w2_scale_x2_vmem.at[bw2_sem_id],
+                dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id],
+                sem=local_sems.at[bw2_sem_id, 2],
+            ).wait()
+        if b2_hbm is not None and bf_id == 0:
+            pltpu.make_async_copy(
+                src_ref=b_b2_x2_vmem.at[bd2_id % 2],
+                dst_ref=b_b2_x2_vmem.at[bd2_id % 2],
+                sem=local_sems.at[bw2_sem_id, 2],
+            ).wait()

     def wait_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id):
-        del local_e_id
+        del local_e_id
         pltpu.make_async_copy(
             src_ref=b_w3_x2_vmem.at[bw3_sem_id],
             dst_ref=b_w3_x2_vmem.at[bw3_sem_id],
             sem=local_sems.at[bw3_sem_id, 3],
         ).wait()
+        if w1_scale_hbm is not None:
+            pltpu.make_async_copy(
+                src_ref=b_w3_scale_x2_vmem.at[bw3_sem_id],
+                dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id],
+                sem=local_sems.at[bw3_sem_id, 3],
+            ).wait()
+        if b1_hbm is not None and bd3_id == 0:
+            pltpu.make_async_copy(
+                src_ref=b_b3_x2_vmem.at[bf_id % 2],
+                dst_ref=b_b3_x2_vmem.at[bf_id % 2],
+                sem=local_sems.at[bw3_sem_id, 3],
+            ).wait()

     def start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, bd2_id):
         next_bd1_id = bd1_id + 1
@@ -520,18 +731,38 @@
     def dynamic_ffn1(
         t_b32_vmem,
         w1_vmem,
+        w1_scale_vmem,
+        b1_vmem,
         w3_vmem,
+        w3_scale_vmem,
+        b3_vmem,
         acc1_vmem,
         acc3_vmem,
         dyn_sz,
         should_init,
     ):
         assert t_b32_vmem.shape == (bt * num_devices, bd1 // t_packing)
-        assert w1_vmem.shape == w3_vmem.shape == (t_packing,
+        assert w1_vmem.shape == w3_vmem.shape == (t_packing, bd1_per_t_packing,
                                                   bf)
         assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf)
         assert bd1 % (t_packing * 128) == 0, (bd1, t_packing)
         assert bd1c % (t_packing * 128) == 0, (bd1c, t_packing)
+        if w1_scale_vmem is not None:
+            assert w1_scale_vmem.shape == (
+                t_packing,
+                bd1_per_t_packing // subc_quant_wsz,
+                1,
+                bf,
+            )
+            assert bd1c_per_t_packing == subc_quant_wsz
+        if w3_scale_vmem is not None:
+            assert w3_scale_vmem.shape == (
+                t_packing,
+                bd1_per_t_packing // subc_quant_wsz,
+                1,
+                bf,
+            )
+            assert bd1c_per_t_packing == subc_quant_wsz

         num_loops = cdiv(dyn_sz, btc)
         repack_ty = jnp.dtype(f"int{t_bitwidth}")
@@ -540,7 +771,7 @@
             for bd1c_id in range(cdiv(bd1, bd1c)):
                 t_b32 = t_b32_vmem[
                     pl.ds(btc_id * btc, btc),
-                    pl.ds(bd1c_id *
+                    pl.ds(bd1c_id * bd1c_per_t_packing, bd1c_per_t_packing),
                 ]
                 for p_id in range(t_packing):
                     t = pltpu.bitcast(t_b32.astype(repack_ty), t_dtype)
@@ -548,21 +779,64 @@
                     for bfc_id in range(cdiv(bf, bfc)):
                         w_slices = (
                             p_id,
-                            pl.ds(bd1c_id *
-
+                            pl.ds(bd1c_id * bd1c_per_t_packing,
+                                  bd1c_per_t_packing),
                             pl.ds(bfc_id * bfc, bfc),
                         )
                         w1 = w1_vmem[*w_slices]
                         acc1 = jnp.dot(t,
                                        w1,
                                        preferred_element_type=jnp.float32)
+
+                        if w1_scale_vmem is not None:
+                            w1_scale_slices = (
+                                p_id,
+                                bd1c_id,
+                                pl.ds(0, 1),
+                                pl.ds(bfc_id * bfc, bfc),
+                            )
+                            # TODO(jevinjiang): can use mosaic to load with stride 0.
+                            w1_scale = jnp.broadcast_to(
+                                w1_scale_vmem[*w1_scale_slices], acc1.shape)
+                            acc1 *= w1_scale
+
                         w3 = w3_vmem[*w_slices]
+
                         acc3 = jnp.dot(t,
                                        w3,
                                        preferred_element_type=jnp.float32)
+
+                        if w3_scale_vmem is not None:
+                            w3_scale_slices = (
+                                p_id,
+                                bd1c_id,
+                                pl.ds(0, 1),
+                                pl.ds(bfc_id * bfc, bfc),
+                            )
+                            w3_scale = jnp.broadcast_to(
+                                w3_scale_vmem[*w3_scale_slices], acc3.shape)
+                            acc3 *= w3_scale
+
                         acc_slices = (pl.ds(btc_id * btc,
                                             btc), pl.ds(bfc_id * bfc, bfc))
                         if should_init and p_id == bd1c_id == 0:
+                            if b1_vmem is not None:
+                                b1_scale_slices = (
+                                    pl.ds(0, 1),
+                                    pl.ds(bfc_id * bfc, bfc),
+                                )
+                                b1 = jnp.broadcast_to(
+                                    b1_vmem[*b1_scale_slices], acc1.shape)
+                                acc1 += b1
+                            if b3_vmem is not None:
+                                b3_scale_slices = (
+                                    pl.ds(0, 1),
+                                    pl.ds(bfc_id * bfc, bfc),
+                                )
+                                b3 = jnp.broadcast_to(
+                                    b3_vmem[*b3_scale_slices], acc1.shape)
+                                acc3 += b3
+
                             acc1_vmem[*acc_slices] = acc1
                             acc3_vmem[*acc_slices] = acc3
                         else:
```
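Why the kernel can multiply the scale in *after* `jnp.dot`: the contracting chunk (`bd1c`/`bfc`) is forced to exactly one `subc_quant_wsz` group, so every term in the accumulated sum shares the same scale and the scale factors out of the contraction. A numeric check of that identity (toy group size, assumed values):

```python
import jax.numpy as jnp
import numpy as np

wsz = 4                                   # toy group size (kernel: % 256 == 0)
t = jnp.arange(1.0, 5.0).reshape(1, wsz)  # activations over ONE group
w_q = jnp.arange(8.0).reshape(wsz, 2)     # "quantized" weights for that group
s = jnp.float32(0.03125)                  # the group's single f32 scale

dequant_first = t @ (w_q * s)             # reference order: dequantize, then dot
scale_after = (t @ w_q) * s               # kernel order: dot, then scale
assert np.allclose(dequant_first, scale_after)
```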
```diff
@@ -575,22 +849,28 @@
         acc1_vmem,
         acc3_vmem,
         w2_vmem,
+        w2_scale_vmem,
+        b2_vmem,
         res_b32_vmem,
         dyn_sz,
         should_init,
     ):
-        assert res_b32_vmem.shape == (bt * num_devices,
-        assert w2_vmem.shape == (t_packing, bf,
-            w2_vmem.shape,
-            t_packing,
-            bf,
-            bd2_per_packing,
-        )
+        assert res_b32_vmem.shape == (bt * num_devices, bd2_per_t_packing)
+        assert w2_vmem.shape == (t_packing, bf, bd2_per_t_packing)
         assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf)
         assert bd2 % (t_packing * 128) == 0, (bd2, t_packing)
         assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing)
         assert t_dtype in (jnp.float32, jnp.bfloat16)

+        if w2_scale_vmem is not None:
+            assert w2_scale_vmem.shape == (
+                t_packing,
+                bf // subc_quant_wsz,
+                1,
+                bd2_per_t_packing,
+            )
+            assert bfc == subc_quant_wsz
+
         num_loops = cdiv(dyn_sz, btc)
         assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing)

@@ -598,22 +878,47 @@
             for bd2c_id in range(cdiv(bd2, bd2c)):
                 res_lst = []
                 for p_id in range(t_packing):
-                    res = jnp.zeros((btc,
+                    res = jnp.zeros((btc, bd2c_per_t_packing),
+                                    dtype=jnp.float32)
+
+                    if b2_vmem is not None and should_init:
+                        b2_scale_slices = (
+                            p_id,
+                            pl.ds(0, 1),
+                            pl.ds(bd2c_id * bd2c_per_t_packing,
+                                  bd2c_per_t_packing),
+                        )
+                        b2 = jnp.broadcast_to(b2_vmem[*b2_scale_slices],
+                                              res.shape)
+                        res += b2
+
                     for bfc_id in range(cdiv(bf, bfc)):
                         acc_slices = (pl.ds(btc_id * btc,
                                             btc), pl.ds(bfc_id * bfc, bfc))
                         acc1 = acc1_vmem[*acc_slices]
                         acc3 = acc3_vmem[*acc_slices]
-                        act =
+                        act = activation_fn(acc1, acc3, act_fn)
                         w2 = w2_vmem[
                             p_id,
                             pl.ds(bfc_id * bfc, bfc),
                             pl.ds(bd2c_id *
-
+                                  bd2c_per_t_packing, bd2c_per_t_packing),
                         ]
-
-
-
+                        acc = jnp.dot(act,
+                                      w2,
+                                      preferred_element_type=jnp.float32)
+                        if w2_scale_vmem is not None:
+                            w2_scale_slices = (
+                                p_id,
+                                bfc_id,
+                                pl.ds(0, 1),
+                                pl.ds(bd2c_id * bd2c_per_t_packing,
+                                      bd2c_per_t_packing),
+                            )
+                            w2_scale = jnp.broadcast_to(
+                                w2_scale_vmem[*w2_scale_slices], acc.shape)
+                            acc *= w2_scale
+                        res += acc
                     res = pltpu.bitcast(res, jnp.uint32)
                     if t_packing == 2:
                         res = res >> 16 << (16 * p_id)
@@ -626,7 +931,7 @@
                         res |= res_lst[i]
                     sliced_res_vmem = res_b32_vmem.at[
                         pl.ds(btc_id * btc, btc),
-                        pl.ds(bd2c_id *
+                        pl.ds(bd2c_id * bd2c_per_t_packing, bd2c_per_t_packing),
                     ]
                     if should_init:
                         sliced_res_vmem[...] = res
```
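The `res >> 16 << (16 * p_id)` dance above packs two bf16 results into one 32-bit word: the top 16 bits of an f32 are exactly its truncated bf16 encoding. A host-side sketch of the same trick with `lax.bitcast_convert_type` (little-endian layout assumed):

```python
import jax.numpy as jnp
from jax import lax

a = lax.bitcast_convert_type(jnp.float32(1.5), jnp.uint32)    # 0x3FC00000
b = lax.bitcast_convert_type(jnp.float32(-2.25), jnp.uint32)  # 0xC0100000
packed = (a >> 16 << 0) | (b >> 16 << 16)                     # p_id = 0, 1
# uint32 -> bf16 bitcast adds a trailing dim of size 2 (low half first).
halves = lax.bitcast_convert_type(packed, jnp.bfloat16)
print(halves)  # [1.5, -2.25]: top 16 f32 bits == the bf16 truncation
```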
```diff
@@ -655,21 +960,33 @@
         e_id = my_id * local_num_experts + local_e_id
         dyn_sz = expert_sizes_x2_smem[bt_sem_id, 0, e_id]

-
-
+        bd1_per_t_packing = bd1 // t_packing
+        bd2_per_t_packing = bd2 // t_packing

         for bf_id in range(num_bf):
             for bd1_id in range(num_bd1):
                 start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, 0)
+                w1_scale_vmem = (None if b_w1_scale_x2_vmem is None else
+                                 b_w1_scale_x2_vmem.at[bw_sem_id])
+                w3_scale_vmem = (None if b_w3_scale_x2_vmem is None else
+                                 b_w3_scale_x2_vmem.at[bw_sem_id])
+                b1_vmem = None if b_b1_x2_vmem is None else b_b1_x2_vmem.at[
+                    bf_id % 2]
+                b3_vmem = None if b_b3_x2_vmem is None else b_b3_x2_vmem.at[
+                    bf_id % 2]
                 wait_fetch_bw1(local_e_id, bw_sem_id, bf_id, bd1_id)
                 wait_fetch_bw3(local_e_id, bw_sem_id, bf_id, bd1_id)

                 dynamic_ffn1(
                     t_b32_vmem=a2a_s_b32_vmem.at[
                         ...,
-                        pl.ds(bd1_id *
+                        pl.ds(bd1_id * bd1_per_t_packing, bd1_per_t_packing)],
                     w1_vmem=b_w1_x2_vmem.at[bw_sem_id],
+                    w1_scale_vmem=w1_scale_vmem,
+                    b1_vmem=b1_vmem,
                     w3_vmem=b_w3_x2_vmem.at[bw_sem_id],
+                    w3_scale_vmem=w3_scale_vmem,
+                    b3_vmem=b3_vmem,
                     acc1_vmem=b_acc1_vmem,
                     acc3_vmem=b_acc3_vmem,
                     dyn_sz=dyn_sz,
@@ -684,13 +1001,19 @@
                 if bf_id == bd2_id == 0:
                     wait_a2a_gather_send(bt_id, e_sem_id, local_e_id - 2)

+                w2_scale_vmem = (None if b_w2_scale_x2_vmem is None else
+                                 b_w2_scale_x2_vmem.at[bw_sem_id])
+                b2_vmem = None if b_b2_x2_vmem is None else b_b2_x2_vmem.at[
+                    bd2_id % 2]
                 dynamic_ffn2(
                     acc1_vmem=b_acc1_vmem,
                     acc3_vmem=b_acc3_vmem,
                     w2_vmem=b_w2_x2_vmem.at[bw_sem_id],
+                    w2_scale_vmem=w2_scale_vmem,
+                    b2_vmem=b2_vmem,
                     res_b32_vmem=a2a_s_acc_b32_vmem.at[
                         ...,
-                        pl.ds(bd2_id *
+                        pl.ds(bd2_id * bd2_per_t_packing, bd2_per_t_packing)],
                     dyn_sz=dyn_sz,
                     should_init=(bf_id == 0),
                 )
@@ -757,31 +1080,42 @@
         b_gating = b_gating_x2_vmem[bt_sem_id]
         b_gating_score = jax.nn.softmax(b_gating, axis=-1)
         top_k_logits_lst, t2e_routing, expert_sizes, expert_starts = get_top_k(
-            b_gating_score, top_k)
+            b_gating_score, top_k, renormalize_topk_logits)

         all_reduce_metadata(bt_sem_id, t2e_routing, expert_starts,
                             expert_sizes)
+        sync_barrier()

+        # Start a2a scatter for first active expert.
         start_a2a_scatter(bt_id=bt_id, e_sem_id=e_sem_id, local_e_id=0)

         def run_per_expert(local_e_id, e_sem_id):
             sync_barrier()
+
+            # Prefetch weights for CURRENT active expert.
+            # TODO(jevinjiang): It is hard to prefetch weights in previous iteration
+            # because the expert_ffn keeps overwriting the buffers. Triple buffering
+            # could resolve this but it takes more VMEM scratch. Need further
+            # experiment on this.
+            start_fetch_bw1(local_e_id, bw1_sem_id=0, bf_id=0, bd1_id=0)
+            start_fetch_bw3(local_e_id, bw3_sem_id=0, bf_id=0, bd3_id=0)
+
+            # Next ids.
             next_e_sem_id = lax.select(e_sem_id == 0, 1, 0)
             next_local_e_id = local_e_id + 1

+            # Start a2a scatter for NEXT active expert.
             @pl.when(next_local_e_id < local_num_experts)
             def _():
                 start_a2a_scatter(bt_id, next_e_sem_id, next_local_e_id)

-            #
-            start_fetch_bw1(local_e_id, bw1_sem_id=0, bf_id=0, bd1_id=0)
-            start_fetch_bw3(local_e_id, bw3_sem_id=0, bf_id=0, bd3_id=0)
-
-            # Wait for a2a scatter and perform FFN for active expert.
+            # Wait a2a scatter for CURRENT active expert.
             wait_a2a_scatter_recv(bt_id, e_sem_id, local_e_id)
+
+            # Perform FFN for CURRENT active expert.
             expert_ffn(bt_id, e_sem_id, local_e_id)

-            #
+            # Start a2a gather to send back tokens for CURRENT active expert.
             start_a2a_gather(bt_id, e_sem_id, local_e_id)

             # A must-wait before next sync_barrier.
@@ -794,7 +1128,10 @@
             e_sem_id,
             unroll=False)

+        # Wait to receive a2a gather for ALL experts.
         wait_a2a_gather_recv_all()
+
+        # Accumulate results for current batch.
         output = bt_acc(bt_id, top_k_logits_lst)

         # Make sure it is safe to overwrite output buffer.
@@ -827,6 +1164,9 @@
     static_argnames=[
         "mesh",
         "top_k",
+        "renormalize_topk_logits",
+        "act_fn",
+        "subc_quant_wsz",
         "bt",
         "bf",
         "bd1",
```
```diff
@@ -846,6 +1186,17 @@ def fused_ep_moe(
     gating_output: jax.Array,  # (num_tokens, num_experts)
     top_k: int,
     *,
+    renormalize_topk_logits: bool = False,
+    act_fn: str = "silu",
+    subc_quant_wsz: int | None = None,
+    w1_scale: (
+        jax.Array | None
+    ) = None,  # F32(num_experts, 2, hidden_size // subc_quant_wsz, 1, intermediate_size)
+    w2_scale: (
+        jax.Array | None
+    ) = None,  # F32(num_experts, intermediate_size // subc_quant_wsz, 1, hidden_size)
+    b1: jax.Array | None = None,  # F32(num_experts, 2, 1, intermediate_size)
+    b2: jax.Array | None = None,  # F32(num_experts, 1, hidden_size)
     # Kernel tuning parameters.
     bt: int,
     bf: int,
```
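A hypothetical call sketch for the extended `fused_ep_moe` signature. Everything here that is not visible in the hunks — the leading `mesh` argument position, the concrete shapes, and the `("data", "model")` mesh layout — is an assumption (the kernel's `get_mesh_device_id` reads a `"data"` axis and `ep_axis_name` defaults to `"model"`):

```python
import jax
import jax.numpy as jnp
import numpy as np

# Assumed 1 x N mesh: trivial "data" axis, expert parallelism on "model".
mesh = jax.sharding.Mesh(
    np.array(jax.devices()).reshape(1, -1), ("data", "model"))

E, H, I, T, K = 16, 1024, 2048, 256, 2  # assumed toy sizes
tokens = jnp.zeros((T, H), jnp.bfloat16)
w1 = jnp.zeros((E, 2, H, I), jnp.bfloat16)
w2 = jnp.zeros((E, I, H), jnp.bfloat16)
gating = jnp.zeros((T, E), jnp.float32)

out = fused_ep_moe(
    mesh, tokens, w1, w2, gating, K,    # positional order is a guess
    act_fn="swigluoai",
    renormalize_topk_logits=True,
    # bf16 => t_packing == 2, so bd1c/bd2c must be multiples of 256.
    bt=64, bf=512, bd1=512, bd2=512,
    btc=64, bfc=256, bd1c=512, bd2c=512,
)
```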
```diff
@@ -855,52 +1206,164 @@
     bfc: int,
     bd1c: int,
     bd2c: int,
-    ep_axis_name: str =
+    ep_axis_name: str = "model",
 ):
-    #
-
-
-
+    # TODO(jevinjiang): move all these assertions to validation function.
+    if len(mesh.shape) != 2:
+        raise NotImplementedError("Only 2D mesh is supported.")
+
+    for axis_name in mesh.axis_names:
+        if axis_name == ep_axis_name:
+            continue
+        if mesh.shape[axis_name] != 1:
+            raise NotImplementedError(
+                f"Expected all non-ep axis to have size 1 in {mesh.shape=}")

     ep_size = mesh.shape[ep_axis_name]
     num_devices = ep_size

-    num_tokens,
+    num_tokens, hidden_size = tokens.shape
     num_experts, intermediate_size, _ = w2.shape

-
-
+    if w1.shape != (num_experts, 2, hidden_size, intermediate_size):
+        raise ValueError(
+            f"Expected {w1.shape=} to be"
+            f" {(num_experts, 2, hidden_size, intermediate_size)}.")
+
+    if w2.shape != (num_experts, intermediate_size, hidden_size):
+        raise ValueError(f"Expected {w2.shape=} to be"
+                         f" {(num_experts, intermediate_size, hidden_size)}.")
+
+    if gating_output.shape != (num_tokens, num_experts):
+        raise ValueError(
+            f"Expected {gating_output.shape=} to be {(num_tokens, num_experts)}."
+        )
+
+    if not (0 < top_k <= num_experts):
+        raise ValueError(
+            f"Expected {top_k=} to be in range (0, {num_experts=}].")
+
+    if hidden_size % 128 != 0 or intermediate_size % 128 != 0:
+        raise ValueError(
+            f"Expected {hidden_size=} and {intermediate_size=} to be aligned to"
+            " 128. Did you pad them with zeros outside the kernel?")
+    if num_tokens % ep_size != 0:
+        raise ValueError(
+            f"Expected {num_tokens=} to be aligned to {ep_size=}.")
+    if num_experts % ep_size != 0:
+        raise ValueError(
+            f"Expected {num_experts=} to be aligned to {ep_size=}.")

     local_num_tokens = num_tokens // ep_size
     # local_num_experts = num_experts // ep_size
     padded_num_experts = align_to(num_experts, 128)
-
+    padded_top_k = align_to(top_k, 128)
     t_dtype = tokens.dtype
     t_packing = get_dtype_packing(t_dtype)
-    hidden_size = align_to(actual_hidden_size, 128 * t_packing)
-    if hidden_size != actual_hidden_size:
-        tokens = jnp.pad(
-            tokens,
-            ((0, 0), (0, hidden_size - actual_hidden_size)),
-            constant_values=0,
-        )
-    tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing)
-    bt = min(bt, local_num_tokens)
-    bf = min(bf, intermediate_size)
-    bd1 = min(bd1, hidden_size)
-    bd2 = min(bd2, hidden_size)
-
-    btc = min(btc, bt * num_devices)
-    bfc = min(bfc, bf)
-    bd1c = min(bd1c, bd1)
-    bd2c = min(bd2c, bd2)
-    assert bfc % 128 == 0
-    assert bd1c % (t_packing * 128) == 0
-    assert bd2c % (t_packing * 128) == 0
-    assert bf % bfc == 0
-    assert bd1 % bd1c == 0
-    assert bd2 % bd2c == 0

+    # Override bt
+    if local_num_tokens <= t_packing * 8:
+        bt = local_num_tokens
+        btc = bt
+    bt = min(local_num_tokens, bt)
+    # The worst case is that all devices send bt to one device.
+    btc = min(bt, btc, bt * num_devices)
+
+    if local_num_tokens % t_packing != 0:
+        raise ValueError(
+            f"Expected {local_num_tokens=} to be aligned to {t_packing=}.")
+
+    if bt % t_packing != 0:
+        raise ValueError(f"Expected {bt=} to be aligned to {t_packing=}.")
+    if local_num_tokens % bt != 0:
+        raise ValueError(
+            f"Expected {local_num_tokens=} to be aligned to {bt=}.")
+
+    if subc_quant_wsz is not None:
+        if subc_quant_wsz <= 0:
+            raise ValueError(f"Expected {subc_quant_wsz=} to be non-negative.")
+        if subc_quant_wsz % 256 != 0:
+            raise ValueError(
+                "Expected {subc_quant_wsz=} to be aligned to 256.")
+        if hidden_size % subc_quant_wsz != 0:
+            raise ValueError(
+                f"Expected {hidden_size=} to be aligned to {subc_quant_wsz=}.")
+        if intermediate_size % subc_quant_wsz != 0:
+            raise ValueError(
+                f"Expected {intermediate_size=} to be aligned to {subc_quant_wsz=}."
+            )
+        # We force compute size of contracting dim to be subc_quant_wsz. So we can
+        # apply same scale after matmul and accumulation.
+        bd1c = subc_quant_wsz * t_packing
+        bfc = subc_quant_wsz
+
+    if bfc % 128 != 0:
+        raise ValueError(f"Expected {bfc=} to be aligned to 128.")
+    if bd1c % (t_packing * 128) != 0:
+        raise ValueError(
+            f"Expected {bd1c=} to be aligned to {t_packing * 128}.")
+    if bd2c % (t_packing * 128) != 0:
+        raise ValueError(
+            f"Expected {bd2c=} to be aligned to {t_packing * 128}.")
+    if bf % bfc != 0:
+        raise ValueError(f"Expected {bf=} to be aligned to {bfc=}.")
+    if bd1 % bd1c != 0:
+        raise ValueError(f"Expected {bd1=} to be aligned to {bd1c=}.")
+    if bd2 % bd2c != 0:
+        raise ValueError(f"Expected {bd2=} to be aligned to {bd2c=}.")
+    if hidden_size % bd1 != 0 or hidden_size % bd2 != 0:
+        raise ValueError(
+            f"Expected {hidden_size=} to be aligned to {bd1=} and {bd2=}.")
+    if intermediate_size % bf != 0:
+        raise ValueError(
+            f"Expected {intermediate_size=} to be aligned to {bf=}.")
+
+    # Note: we should dump scale as the kernel expected shape in the
+    # checkpoint offline or reshape right after weight loading.
+    if w1_scale is not None:
+        expected_w1_scale_shape = (
+            num_experts,
+            2,
+            hidden_size // subc_quant_wsz,
+            1,
+            intermediate_size,
+        )
+        if w1_scale.shape != expected_w1_scale_shape:
+            raise ValueError(
+                f"Expected {w1_scale.shape=} to be {expected_w1_scale_shape}.")
+        if w1_scale.dtype != jnp.float32:
+            w1_scale = w1_scale.astype(jnp.float32)
+
+    if w2_scale is not None:
+        expected_w2_scale_shape = (
+            num_experts,
+            intermediate_size // subc_quant_wsz,
+            1,
+            hidden_size,
+        )
+        if w2_scale.shape != expected_w2_scale_shape:
+            raise ValueError(
+                f"Expected {w2_scale.shape=} to be {expected_w2_scale_shape}.")
+        if w2_scale.dtype != jnp.float32:
+            w2_scale = w2_scale.astype(jnp.float32)
+
+    if b1 is not None:
+        expected_b1_shape = (num_experts, 2, 1, intermediate_size)
+        if b1.shape != expected_b1_shape:
+            raise ValueError(
+                f"Expected {b1.shape=} to be {expected_b1_shape}.")
+        if b1.dtype != jnp.float32:
+            b1 = b1.astype(jnp.float32)
+
+    if b2 is not None:
+        expected_b2_shape = (num_experts, 1, hidden_size)
+        if b2.shape != expected_b2_shape:
+            raise ValueError(
+                f"Expected {b2.shape=} to be {expected_b2_shape}.")
+        if b2.dtype != jnp.float32:
+            b2 = b2.astype(jnp.float32)
+
+    # Prepare inputs for the kernel.
     if padded_num_experts != gating_output.shape[-1]:
         gating_output = jnp.pad(
             gating_output,
```
```diff
@@ -908,13 +1371,20 @@
             constant_values=-jnp.inf,
         )

-
+    tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing)
+
+    hbm_block_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM)
+    renorm_str = "-renorm_k" if renormalize_topk_logits else ""
+    scope_name = f"fused-moe-k_{top_k}{renorm_str}-bt_{bt}_{btc}-bf_{bf}_{bfc}-bd1_{bd1}_{bd1c}-bd2_{bd2}_{bd2c}"
     fused_moe = jax.named_scope(scope_name)(
         pl.pallas_call(
             functools.partial(
                 _fused_ep_moe_kernel,
                 top_k=top_k,
+                renormalize_topk_logits=renormalize_topk_logits,
                 ep_axis_name=ep_axis_name,
+                act_fn=act_fn,
+                subc_quant_wsz=subc_quant_wsz,
                 bt=bt,
                 bf=bf,
                 bd1=bd1,
@@ -929,16 +1399,22 @@
             grid_spec=pltpu.PrefetchScalarGridSpec(
                 num_scalar_prefetch=0,
                 in_specs=[
-
-
-
-
-
+                    hbm_block_spec,  # tokens_hbm
+                    hbm_block_spec,  # w1_hbm
+                    hbm_block_spec,  # w2_hbm
+                    None
+                    if w1_scale is None else hbm_block_spec,  # w1_scale_hbm
+                    None
+                    if w2_scale is None else hbm_block_spec,  # w2_scale_hbm
+                    None if b1 is None else hbm_block_spec,  # b1_hbm
+                    None if b2 is None else hbm_block_spec,  # b2_hbm
+                    hbm_block_spec,  # gating_output_hbm
+                    hbm_block_spec,  # a2a_g_hbm
                 ],
                 out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
                 scratch_shapes=([
                     # t2e_routing_x2_smem
-                    pltpu.SMEM((2, bt,
+                    pltpu.SMEM((2, bt, padded_top_k), jnp.int32),
                     # d2e_count_x2_smem
                     pltpu.SMEM((2, num_devices, 1, padded_num_experts),
                                jnp.int32),
@@ -984,6 +1460,67 @@
                     pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
                     # b_w2_x2_vmem
                     pltpu.VMEM((2, t_packing, bf, bd2 // t_packing), w2.dtype),
+                    # b_w1_scale_x2_vmem
+                    (None if w1_scale is None else pltpu.VMEM(
+                        (
+                            2,
+                            t_packing,
+                            bd1 // t_packing // subc_quant_wsz,
+                            1,
+                            bf,
+                        ),
+                        jnp.float32,
+                    )),
+                    # b_w3_scale_x2_vmem
+                    (None if w1_scale is None else pltpu.VMEM(
+                        (
+                            2,
+                            t_packing,
+                            bd1 // t_packing // subc_quant_wsz,
+                            1,
+                            bf,
+                        ),
+                        jnp.float32,
+                    )),
+                    # b_w2_scale_x2_vmem
+                    (None if w2_scale is None else pltpu.VMEM(
+                        (
+                            2,
+                            t_packing,
+                            bf // subc_quant_wsz,
+                            1,
+                            bd2 // t_packing,
+                        ),
+                        jnp.float32,
+                    )),
+                    # b_b1_x2_vmem
+                    (None if b1 is None else pltpu.VMEM(
+                        (
+                            2,
+                            1,
+                            bf,
+                        ),
+                        jnp.float32,
+                    )),
+                    # b_b3_x2_vmem
+                    (None if b1 is None else pltpu.VMEM(
+                        (
+                            2,
+                            1,
+                            bf,
+                        ),
+                        jnp.float32,
+                    )),
+                    # b_b2_x2_vmem
+                    (None if b2 is None else pltpu.VMEM(
+                        (
+                            2,
+                            t_packing,
+                            1,
+                            bd2 // t_packing,
+                        ),
+                        jnp.float32,
+                    )),
                     # b_acc_vmem
                     pltpu.VMEM((bt * num_devices, 1, bf * 2), jnp.float32),
                     # local_sems
@@ -1006,30 +1543,62 @@
     ))

     @jax.jit
-    @
-        shard_map.shard_map,
+    @jax.shard_map(
         mesh=mesh,
-        in_specs=(
-
+        in_specs=(
+            P(ep_axis_name),  # tokens_hbm
+            P(ep_axis_name),  # w1_hbm
+            P(ep_axis_name),  # w2_hbm
+            None if w1_scale is None else P(ep_axis_name),  # w1_scale_hbm
+            None if w2_scale is None else P(ep_axis_name),  # w2_scale_hbm
+            None if b1 is None else P(ep_axis_name),  # b1_hbm
+            None if b2 is None else P(ep_axis_name),  # b2_hbm
+            P(ep_axis_name),  # gating_output_hbm
+            P(),  # a2a_g_hbm
+        ),
         out_specs=P(ep_axis_name),
-
+        check_vma=False,
     )
-    def kernel(
+    def kernel(
+        tokens,
+        w1,
+        w2,
+        w1_scale,
+        w2_scale,
+        b1,
+        b2,
+        gating_output,
+        a2a_g_hbm_scratch,
+    ):
         return fused_moe(
-            pltpu.with_memory_space_constraint(tokens,
-
-            pltpu.with_memory_space_constraint(
-            pltpu.with_memory_space_constraint(
-            pltpu.with_memory_space_constraint(
+            pltpu.with_memory_space_constraint(tokens,
+                                               pltpu.HBM),  # tokens_hbm
+            pltpu.with_memory_space_constraint(w1, pltpu.HBM),  # w1_hbm
+            pltpu.with_memory_space_constraint(w2, pltpu.HBM),  # w2_hbm
+            (None if w1_scale is None else pltpu.with_memory_space_constraint(
+                w1_scale, pltpu.HBM)),  # w1_scale_hbm
+            (None if w2_scale is None else pltpu.with_memory_space_constraint(
+                w2_scale, pltpu.HBM)),  # w2_scale_hbm
+            (None if b1 is None else pltpu.with_memory_space_constraint(
+                b1, pltpu.HBM)),  # b1_hbm
+            (None if b2 is None else pltpu.with_memory_space_constraint(
+                b2, pltpu.HBM)),  # b2_hbm
+            pltpu.with_memory_space_constraint(gating_output,
+                                               pltpu.HBM),  # gating_output_hbm
+            pltpu.with_memory_space_constraint(a2a_g_hbm_scratch,
+                                               pltpu.HBM),  # a2a_g_hbm
        )

     a2a_g_hbm_scratch = pl.empty(
         (num_experts, bt, t_packing, hidden_size // t_packing), t_dtype)
-
+    return kernel(
         tokens,
         w1,
         w2,
+        w1_scale,
+        w2_scale,
+        b1,
+        b2,
         gating_output,
         a2a_g_hbm_scratch,
     )
-    return results[:, :actual_hidden_size]
```
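The wrapper also migrates from the removed `jax.experimental.shard_map` import to the `jax.shard_map` decorator API, with `check_vma=False` opting out of shard_map's automatic replication checking for the Pallas body. A minimal, self-contained sketch of that decorator form (assumes a JAX release that exposes `jax.shard_map`; older releases used `jax.experimental.shard_map.shard_map` instead):

```python
import jax
import jax.numpy as jnp
import numpy as np

P = jax.sharding.PartitionSpec
mesh = jax.sharding.Mesh(np.array(jax.devices()), ("model",))

@jax.shard_map(mesh=mesh, in_specs=P("model"), out_specs=P("model"))
def double_shard(x):
    # Runs per device on the local shard: (rows // num_devices, cols).
    return x * 2

x = jnp.arange(float(8 * len(jax.devices()))).reshape(-1, 8)
print(double_shard(x).shape)  # same global shape, computed shard-wise
```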