tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511180814__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +34 -303
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
- tests/lora/test_layers.py +6 -0
- tests/lora/utils.py +8 -0
- tests/test_envs.py +11 -32
- tests/test_utils.py +2 -1
- tpu_inference/__init__.py +3 -22
- tpu_inference/core/disagg_utils.py +8 -6
- tpu_inference/distributed/tpu_connector.py +4 -3
- tpu_inference/distributed/utils.py +2 -3
- tpu_inference/envs.py +8 -61
- tpu_inference/executors/ray_distributed_executor.py +2 -9
- tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +145 -266
- tpu_inference/layers/common/attention_interface.py +1 -7
- tpu_inference/layers/common/sharding.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +208 -170
- tpu_inference/layers/vllm/quantization/common.py +1 -6
- tpu_inference/layers/vllm/quantization/mxfp4.py +73 -138
- tpu_inference/layers/vllm/quantization/unquantized.py +64 -58
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +2 -1
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/common/model_loader.py +10 -43
- tpu_inference/models/jax/llama3.py +1 -2
- tpu_inference/models/jax/llama_eagle3.py +5 -8
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +1 -2
- tpu_inference/models/jax/qwen2_5_vl.py +48 -163
- tpu_inference/models/jax/qwen3.py +1 -2
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
- tpu_inference/models/jax/utils/weight_utils.py +143 -198
- tpu_inference/models/vllm/vllm_model_wrapper.py +8 -14
- tpu_inference/platforms/tpu_platform.py +31 -37
- tpu_inference/runner/compilation_manager.py +58 -141
- tpu_inference/runner/kv_cache.py +1 -1
- tpu_inference/runner/kv_cache_manager.py +18 -17
- tpu_inference/runner/persistent_batch_manager.py +2 -40
- tpu_inference/runner/structured_decoding_manager.py +3 -2
- tpu_inference/runner/tpu_runner.py +147 -271
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +21 -71
- tpu_inference/tpu_info.py +3 -4
- tpu_inference/utils.py +13 -36
- tpu_inference/worker/tpu_worker.py +25 -162
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA +3 -4
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/RECORD +55 -50
- tpu_inference/models/jax/llama_guard_4.py +0 -361
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/WHEEL +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/top_level.txt +0 -0
@@ -7,6 +7,7 @@ import jax.numpy as jnp
 from jax import lax
 from jax._src import dtypes
 from jax.experimental import pallas as pl
+from jax.experimental import shard_map
 from jax.experimental.pallas import tpu as pltpu
 
 P = jax.sharding.PartitionSpec
@@ -34,49 +35,13 @@ def broadcast_minor(src, shape):
                      axis=-1)[..., :shape[-1]]
 
 
-def swigluoai(gate: jax.Array,
-              up: jax.Array,
-              *,
-              alpha: float = 1.702,
-              limit: float = 7.0) -> jax.Array:
-    """Activation used in some models such as GPT-OSS."""
-    gate = jnp.clip(gate, a_max=limit)
-    up = jnp.clip(up, a_min=-limit, a_max=limit)
-    glu = gate * jax.nn.sigmoid(alpha * gate)
-    return (up + 1.0) * glu
-
-
-def activation_fn(acc1, acc3, act_fn):
-    if act_fn == "silu":
-        return jax.nn.silu(acc1) * acc3
-    elif act_fn == "gelu":
-        return jax.nn.gelu(acc1) * acc3
-    elif act_fn == "swigluoai":
-        return swigluoai(acc1, acc3)
-    else:
-        raise RuntimeError(f"Unsupported activation function: {act_fn}")
-
-
 def ref_moe(
-
-
-
-
-
-
-    renormalize_topk_logits: bool = False,
-    activation="silu",
-    subc_quant_wsz: int | None = None,
-    w1_scale:
-    (
-        jax.Array | None
-    ) = None,  # (num_experts, 2, cdiv(hidden_size, subc_quant_wsz), intermediate_size)
-    w2_scale:
-    (
-        jax.Array | None
-    ) = None,  # (num_experts, cdiv(intermediate_size, subc_quant_wsz), hidden_size)
-    b1: jax.Array | None = None,  # (num_experts, 2, intermediate_size)
-    b2: jax.Array | None = None,  # (num_experts, hidden_size)
+    tokens: jax.Array,  # (num_tokens, hidden_size)
+    w1: jax.Array,  # (num_experts, 2, hidden_size, intermediate_size)
+    w2: jax.Array,  # (num_experts, intermediate_size, hidden_size)
+    gating_output: jax.Array,  # (num_tokens, num_experts)
+    top_k: int,
+    activation="silu",
 ):
     n_tokens = tokens.shape[0]  # num_tokens
 
@@ -88,12 +53,7 @@ def ref_moe(
     top_k_logits, top_k_indices = lax.top_k(
        gating_logits, top_k)  # [num_tokens, top_k], [num_tokens, top_k]
 
-    if renormalize_topk_logits:
-        top_k_logits = top_k_logits / jnp.sum(
-            top_k_logits, axis=-1, keepdims=True)
-
     t_outputs = []
-    hidden_size, intermediate_size = w1.shape[-2:]
 
     # Process each token individually
     for i in range(n_tokens):
@@ -105,24 +65,10 @@ def ref_moe(
         # Process each selected expert for the current token
         for expert_id in assigned_expert_ids:
             # Get expert weights
-            expert_w1 = w1[expert_id, 0].astype(jnp.float32)
-            expert_w3 = w1[expert_id, 1].astype(jnp.float32)
-            if w1_scale is not None:
-                expert_w1 *= jnp.repeat(w1_scale[expert_id, 0],
-                                        subc_quant_wsz,
-                                        axis=0)[:hidden_size]
-                expert_w3 *= jnp.repeat(w1_scale[expert_id, 1],
-                                        subc_quant_wsz,
-                                        axis=0)[:hidden_size]
             expert_weight_1 = jnp.concat(
-                [
+                [w1[expert_id, 0], w1[expert_id, 1]],
                 axis=-1)  # [d_model, 2 * intermediate_size]
-            expert_weight_2 = w2[expert_id]
-                jnp.float32)  # [intermediate_size, d_model]
-            if w2_scale is not None:
-                expert_weight_2 *= jnp.repeat(w2_scale[expert_id],
-                                              subc_quant_wsz,
-                                              axis=0)[:intermediate_size]
+            expert_weight_2 = w2[expert_id]  # [intermediate_size, d_model]
 
             # First linear layer with SwiGLU activation
             gmm_1_out = curr_token @ expert_weight_1  # [1, 2 * intermediate_size]
@@ -131,17 +77,21 @@ def ref_moe(
             gmm1_w1_proj, gmm1_w3_proj = jnp.split(
                 gmm_1_out, 2,
                 axis=-1)  # [1, intermediate_size], [1, intermediate_size]
-            if b1 is not None:
-                gmm1_w1_proj += b1[expert_id:expert_id + 1, 0]
-                gmm1_w3_proj += b1[expert_id:expert_id + 1, 1]
 
             # Apply gated activation: activation(gate) * up
-
+            if activation == "silu":
+                act = jax.nn.silu(
+                    gmm1_w1_proj) * gmm1_w3_proj  # [1, intermediate_size]
+            elif activation == "gelu":
+                act = jax.nn.gelu(
+                    gmm1_w1_proj) * gmm1_w3_proj  # [1, intermediate_size]
+            else:
+                raise ValueError(
+                    f"Unsupported activation: {activation}. Use 'silu' or 'gelu'."
+                )
 
             # Second linear layer (down projection)
             gmm_2_out = act @ expert_weight_2  # [1, d_model]
-            if b2 is not None:
-                gmm_2_out += b2[expert_id:expert_id + 1]
             tok_expert_act.append(gmm_2_out)
 
         # Combine outputs from all selected experts
@@ -155,7 +105,7 @@ def ref_moe(
             axis=0,
             keepdims=True)  # [1, d_model]
 
-        t_outputs.append(weighted_output
+        t_outputs.append(weighted_output)
 
     return jnp.concatenate(t_outputs, axis=0)  # [num_tokens, d_model]
 
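
For reference, the removed swigluoai helper (deleted in the hunk above) can be read as a small standalone function. The sketch below is lightly adapted from the removed lines (positional clip arguments are used so it runs on current jax releases); it is not part of the new code.

import jax
import jax.numpy as jnp

def swigluoai(gate, up, *, alpha=1.702, limit=7.0):
    # Clamp the gate from above and the up-projection on both sides,
    # then apply a sigmoid-gated linear unit, as in the removed helper.
    gate = jnp.clip(gate, None, limit)
    up = jnp.clip(up, -limit, limit)
    glu = gate * jax.nn.sigmoid(alpha * gate)
    return (up + 1.0) * glu
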
@@ -165,13 +115,6 @@ def _fused_ep_moe_kernel(
     tokens_hbm,  # (local_num_tokens, t_packing, hidden_size // t_packing)
     w1_hbm,  # (local_num_experts, 2, hidden_size, intermediate_size)
     w2_hbm,  # (local_num_experts, intermediate_size, hidden_size)
-    # TODO(jevinjiang): We choose F32 scale for easier slicing. The extra
-    # latency should be hidden in the pipeline overlaping. But is there a better
-    # way to do this?
-    w1_scale_hbm,  # None | F32(local_num_experts, 2, cdiv(hidden_size, subc_quant_wsz), 1, intermediate_size)
-    w2_scale_hbm,  # None | F32(local_num_experts, cdiv(intermediate_size, subc_quant_wsz), 1, hidden_size)
-    b1_hbm,  # None | F32(local_num_experts, 2, 1, intermediate_size)
-    b2_hbm,  # None | F32(local_num_experts, 1, hidden_size)
     gating_hbm,  # (local_num_tokens, padded_num_experts)
     a2a_g_hbm,  # (num_experts, bt, t_packing, hidden_size // t_packing)
     # Output
@@ -193,12 +136,6 @@ def _fused_ep_moe_kernel(
     b_w1_x2_vmem,  # <bw_sem_id> (2, t_packing, bd1 // t_packing, bf)
     b_w3_x2_vmem,  # <bw_sem_id> (2, t_packing, bd1 // t_packing, bf)
     b_w2_x2_vmem,  # <bw_sem_id> (2, t_packing, bf, bd2 // t_packing)
-    b_w1_scale_x2_vmem,  # None | <bw_sem_id> (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf)
-    b_w3_scale_x2_vmem,  # None | <bw_sem_id> (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf)
-    b_w2_scale_x2_vmem,  # None | <bw_sem_id> (2, t_packing, bf // subc_quant_wsz, 1, bd2 // t_packing)
-    b_b1_x2_vmem,  # None | <bw_sem_id> (2, 1, bf)
-    b_b3_x2_vmem,  # None | <bw_sem_id> (2, 1, bf)
-    b_b2_x2_vmem,  # None | <bw_sem_id> (2, t_packing, 1, bd2 // t_packing)
     b_acc_vmem,  # F32(bt * num_devices, 1, bf * 2)
     ### Semaphores:
     local_sems,  # (2, 5): 2 x [b_gating_sem, b_w1_sem, b_w2_sem, b_w3_sem, b_output_sem]
@@ -208,10 +145,7 @@ def _fused_ep_moe_kernel(
     a2a_acc_sem,
     *,
     top_k: int,
-    renormalize_topk_logits: bool,
     ep_axis_name: str,
-    act_fn: str,
-    subc_quant_wsz: int | None = None,
     # Kernel tuning params.
     bt: int,  # Block size of local_num_tokens.
     bf: int,  # Block size of intermediate_size.
@@ -226,53 +160,34 @@ def _fused_ep_moe_kernel(
     num_devices = lax.axis_size(ep_axis_name)
     local_num_tokens = tokens_hbm.shape[0]
     local_num_experts, intermediate_size, hidden_size = w2_hbm.shape
+    # num_experts = local_num_experts * num_devices
+    # padded_num_experts = expert_starts_x2_smem.shape[-1]
     right_id = (my_id + 1) % num_devices
 
     t_dtype = tokens_hbm.dtype
     t_packing = get_dtype_packing(t_dtype)
     t_bitwidth = 32 // t_packing
     assert a2a_g_hbm.dtype == t_dtype
-    assert w1_hbm.dtype ==
+    assert w1_hbm.dtype == t_dtype
+    assert w2_hbm.dtype == t_dtype
 
-
-    assert
-
-
-
-
-    assert bd1c % t_packing == 0
-    assert bd2c % t_packing == 0
-
-    h_per_t_packing = hidden_size // t_packing
-    assert tokens_hbm.shape[-1] == h_per_t_packing
-    bd1_per_t_packing = bd1 // t_packing
-    bd2_per_t_packing = bd2 // t_packing
-    bd1c_per_t_packing = bd1c // t_packing
-    bd2c_per_t_packing = bd2c // t_packing
-
-    if subc_quant_wsz is not None:
-        assert subc_quant_wsz % 256 == 0
-        assert bd1c_per_t_packing == subc_quant_wsz
-        assert bfc == subc_quant_wsz
-        assert bd1 % subc_quant_wsz == 0
-        assert bf % subc_quant_wsz == 0
-        assert bd1_per_t_packing % subc_quant_wsz == 0
-        assert h_per_t_packing % subc_quant_wsz == 0
+    h_per_packing = hidden_size // t_packing
+    assert tokens_hbm.shape[-1] == h_per_packing
+    bd1_per_packing = bd1 // t_packing
+    bd2_per_packing = bd2 // t_packing
+    bd1c_per_packing = bd1c // t_packing
+    bd2c_per_packing = bd2c // t_packing
 
     num_bt = cdiv(local_num_tokens, bt)
     num_bf = cdiv(intermediate_size, bf)
     num_bd1 = cdiv(hidden_size, bd1)
     num_bd2 = cdiv(hidden_size, bd2)
 
-    def get_mesh_device_id(ep_rank):
-        dp_rank = jax.lax.axis_index("data")
-        return (dp_rank, ep_rank)
-
     def sync_barrier():
         barrier_sem = pltpu.get_barrier_semaphore()
         pltpu.semaphore_signal(
             barrier_sem,
-            device_id=
+            device_id=(0, right_id),
             device_id_type=pltpu.DeviceIdType.MESH,
         )
         pltpu.semaphore_wait(barrier_sem, 1)
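
The kernel above derives t_packing from the token dtype and then treats 32-bit words as the unit of data movement (t_bitwidth = 32 // t_packing). The helper below is a hypothetical sketch of what get_dtype_packing plausibly returns; the real helper lives elsewhere in the package and may differ.

import jax.numpy as jnp

def get_dtype_packing(dtype):
    # Assumed behaviour: how many elements of `dtype` fit in one 32-bit word
    # (1 for float32, 2 for bfloat16), consistent with t_bitwidth above.
    return 32 // (jnp.dtype(dtype).itemsize * 8)

assert get_dtype_packing(jnp.float32) == 1
assert get_dtype_packing(jnp.bfloat16) == 2
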
@@ -297,7 +212,7 @@ def _fused_ep_moe_kernel(
             sem=b_gating_sem,
         ).wait()
 
-    def get_top_k(input, top_k
+    def get_top_k(input, top_k):
         assert len(input.shape) == 2, input.shape
         input = input.astype(jnp.float32)
         top_k_logits_lst = []
@@ -305,15 +220,11 @@ def _fused_ep_moe_kernel(
         t2e = jnp.zeros(input.shape, dtype=jnp.int32)
         t2e_routing = jnp.zeros(input.shape, dtype=jnp.int32)
         iota = jax.lax.broadcasted_iota(jnp.int32, input.shape, 1)
-        top_k_logits_sum = jnp.zeros((input.shape[0], 128), jnp.float32)
-
         for k_id in range(top_k):
-            # TODO(jevinjiang): return both top_k values and indices in Mosaic
+            # TODO(jevinjiang): return both top_k values and indices in op in Mosaic
             top_k_logits = jnp.broadcast_to(
                 jnp.max(input, axis=1, keepdims=True),
                 (input.shape[0], 128)).astype(input.dtype)
-            if renormalize_topk_logits:
-                top_k_logits_sum += top_k_logits
             top_k_logits_lst.append(top_k_logits)
             # TODO(jevinjiang): support bf16 argmax in Mosaic
             top_k_indices = jnp.broadcast_to(
@@ -325,11 +236,6 @@ def _fused_ep_moe_kernel(
             if k_id != top_k - 1:
                 input = jnp.where(mask, -jnp.inf, input)
 
-        if renormalize_topk_logits:
-            for k_id in range(top_k):
-                top_k_logits_lst[
-                    k_id] = top_k_logits_lst[k_id] / top_k_logits_sum
-
         expert_sizes = jnp.sum(t2e, axis=0, keepdims=True)
         expert_starts = jnp.zeros_like(expert_sizes)
         return top_k_logits_lst, t2e_routing, expert_sizes, expert_starts
@@ -371,7 +277,7 @@ def _fused_ep_moe_kernel(
             dst_ref=d2e_count_vmem.at[row_id],
             send_sem=send_sem,
             recv_sem=recv_sem,
-            device_id=
+            device_id=(0, right_id),
             device_id_type=pltpu.DeviceIdType.MESH,
         ).wait()
         row_id = (row_id + num_devices - 1) % num_devices
@@ -453,8 +359,10 @@ def _fused_ep_moe_kernel(
                        pl.ds(start, remote_sz)],
             send_sem=send_sems.at[e_sem_id],
             recv_sem=recv_sems.at[e_sem_id],
-            device_id=
-
+            device_id=(
+                0,
+                recv_id,
+            ),
         ).start()
         a2a_s_sends_x2_smem[e_sem_id] = send_sz
 
@@ -498,8 +406,7 @@ def _fused_ep_moe_kernel(
             dst_ref=a2a_g_hbm.at[my_e_id, pl.ds(0, remote_sz)],
             send_sem=send_sems.at[e_sem_id],
             recv_sem=a2a_gather_sem,
-            device_id=
-            device_id_type=pltpu.DeviceIdType.MESH,
+            device_id=(0, recv_id),
         ).start()
         start += sz
 
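
The get_top_k loop above is written in a Mosaic-friendly, iterative style. As a plain-JAX point of comparison (a sketch, not the kernel's implementation), the same routing metadata can be expressed directly with lax.top_k:

import jax
import jax.numpy as jnp
from jax import lax

def topk_routing_reference(gating_logits, top_k, num_experts):
    # Per-token top-k routing scores/indices plus how many tokens each expert
    # receives, which is what the kernel's routing metadata describes.
    scores = jax.nn.softmax(gating_logits, axis=-1)
    top_k_scores, top_k_ids = lax.top_k(scores, top_k)   # [T, K], [T, K]
    one_hot = jax.nn.one_hot(top_k_ids, num_experts, dtype=jnp.int32)
    expert_sizes = one_hot.sum(axis=(0, 1))              # tokens per expert
    return top_k_scores, top_k_ids, expert_sizes
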
@@ -528,173 +435,68 @@ def _fused_ep_moe_kernel(
 
     def start_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id):
         for p in range(t_packing):
-            offset = p *
+            offset = p * h_per_packing + bd1_id * bd1_per_packing
             pltpu.make_async_copy(
                 src_ref=w1_hbm.at[
                     local_e_id,
                     0,
-                    pl.ds(offset,
+                    pl.ds(offset, bd1_per_packing),
                     pl.ds(bf_id * bf, bf),
                 ],
                 dst_ref=b_w1_x2_vmem.at[bw1_sem_id, p],
                 sem=local_sems.at[bw1_sem_id, 1],
             ).start()
-            if w1_scale_hbm is not None:
-                assert subc_quant_wsz is not None
-                pltpu.make_async_copy(
-                    src_ref=w1_scale_hbm.at[
-                        local_e_id,
-                        0,
-                        pl.ds(
-                            offset // subc_quant_wsz,
-                            bd1_per_t_packing // subc_quant_wsz,
-                        ),
-                        pl.ds(0, 1),
-                        pl.ds(bf_id * bf, bf),
-                    ],
-                    dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id, p],
-                    sem=local_sems.at[bw1_sem_id, 1],
-                ).start()
-            if b1_hbm is not None and bd1_id == 0:
-                pltpu.make_async_copy(
-                    src_ref=b1_hbm.at[local_e_id, 0,
-                                      pl.ds(0, 1),
-                                      pl.ds(bf_id * bf, bf)],
-                    dst_ref=b_b1_x2_vmem.at[bf_id % 2],
-                    sem=local_sems.at[bw1_sem_id, 1],
-                ).start()
 
     def start_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id):
         for p in range(t_packing):
-            offset = p *
+            offset = p * h_per_packing + bd2_id * bd2_per_packing
             pltpu.make_async_copy(
                 src_ref=w2_hbm.at[
                     local_e_id,
                     pl.ds(bf_id * bf, bf),
-                    pl.ds(offset,
+                    pl.ds(offset, bd2_per_packing),
                 ],
                 dst_ref=b_w2_x2_vmem.at[bw2_sem_id, p],
                 sem=local_sems.at[bw2_sem_id, 2],
             ).start()
-            if w2_scale_hbm is not None:
-                assert subc_quant_wsz is not None
-                pltpu.make_async_copy(
-                    src_ref=w2_scale_hbm.at[
-                        local_e_id,
-                        pl.ds(bf_id * bf // subc_quant_wsz, bf //
-                              subc_quant_wsz),
-                        pl.ds(0, 1),
-                        pl.ds(offset, bd2_per_t_packing),
-                    ],
-                    dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id, p],
-                    sem=local_sems.at[bw2_sem_id, 2],
-                ).start()
-            if b2_hbm is not None and bf_id == 0:
-                pltpu.make_async_copy(
-                    src_ref=b2_hbm.at[local_e_id,
-                                      pl.ds(0, 1),
-                                      pl.ds(offset, bd2_per_t_packing)],
-                    dst_ref=b_b2_x2_vmem.at[bd2_id % 2, p],
-                    sem=local_sems.at[bw2_sem_id, 2],
-                ).start()
 
     def start_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id):
         for p in range(t_packing):
-            offset = p *
+            offset = p * h_per_packing + bd3_id * bd1_per_packing
             pltpu.make_async_copy(
                 src_ref=w1_hbm.at[
                     local_e_id,
                     1,
-                    pl.ds(offset,
+                    pl.ds(offset, bd1_per_packing),
                     pl.ds(bf_id * bf, bf),
                 ],
                 dst_ref=b_w3_x2_vmem.at[bw3_sem_id, p],
                 sem=local_sems.at[bw3_sem_id, 3],
             ).start()
-            if w1_scale_hbm is not None:
-                assert subc_quant_wsz is not None
-                pltpu.make_async_copy(
-                    src_ref=w1_scale_hbm.at[
-                        local_e_id,
-                        1,
-                        pl.ds(
-                            offset // subc_quant_wsz,
-                            bd1_per_t_packing // subc_quant_wsz,
-                        ),
-                        pl.ds(0, 1),
-                        pl.ds(bf_id * bf, bf),
-                    ],
-                    dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id, p],
-                    sem=local_sems.at[bw3_sem_id, 3],
-                ).start()
-            if b1_hbm is not None and bd3_id == 0:
-                pltpu.make_async_copy(
-                    src_ref=b1_hbm.at[local_e_id, 1,
-                                      pl.ds(0, 1),
-                                      pl.ds(bf_id * bf, bf)],
-                    dst_ref=b_b3_x2_vmem.at[bf_id % 2],
-                    sem=local_sems.at[bw3_sem_id, 3],
-                ).start()
 
     def wait_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id):
-        del local_e_id
+        del local_e_id, bf_id, bd1_id
         pltpu.make_async_copy(
             src_ref=b_w1_x2_vmem.at[bw1_sem_id],
             dst_ref=b_w1_x2_vmem.at[bw1_sem_id],
             sem=local_sems.at[bw1_sem_id, 1],
         ).wait()
-        if w1_scale_hbm is not None:
-            pltpu.make_async_copy(
-                src_ref=b_w1_scale_x2_vmem.at[bw1_sem_id],
-                dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id],
-                sem=local_sems.at[bw1_sem_id, 1],
-            ).wait()
-        if b1_hbm is not None and bd1_id == 0:
-            pltpu.make_async_copy(
-                src_ref=b_b1_x2_vmem.at[bf_id % 2],
-                dst_ref=b_b1_x2_vmem.at[bf_id % 2],
-                sem=local_sems.at[bw1_sem_id, 1],
-            ).wait()
 
     def wait_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id):
-        del local_e_id
+        del local_e_id, bf_id, bd2_id
         pltpu.make_async_copy(
             src_ref=b_w2_x2_vmem.at[bw2_sem_id],
             dst_ref=b_w2_x2_vmem.at[bw2_sem_id],
             sem=local_sems.at[bw2_sem_id, 2],
         ).wait()
-        if w2_scale_hbm is not None:
-            pltpu.make_async_copy(
-                src_ref=b_w2_scale_x2_vmem.at[bw2_sem_id],
-                dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id],
-                sem=local_sems.at[bw2_sem_id, 2],
-            ).wait()
-        if b2_hbm is not None and bf_id == 0:
-            pltpu.make_async_copy(
-                src_ref=b_b2_x2_vmem.at[bd2_id % 2],
-                dst_ref=b_b2_x2_vmem.at[bd2_id % 2],
-                sem=local_sems.at[bw2_sem_id, 2],
-            ).wait()
 
     def wait_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id):
-        del local_e_id
+        del local_e_id, bf_id, bd3_id
         pltpu.make_async_copy(
             src_ref=b_w3_x2_vmem.at[bw3_sem_id],
             dst_ref=b_w3_x2_vmem.at[bw3_sem_id],
             sem=local_sems.at[bw3_sem_id, 3],
         ).wait()
-        if w1_scale_hbm is not None:
-            pltpu.make_async_copy(
-                src_ref=b_w3_scale_x2_vmem.at[bw3_sem_id],
-                dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id],
-                sem=local_sems.at[bw3_sem_id, 3],
-            ).wait()
-        if b1_hbm is not None and bd3_id == 0:
-            pltpu.make_async_copy(
-                src_ref=b_b3_x2_vmem.at[bf_id % 2],
-                dst_ref=b_b3_x2_vmem.at[bf_id % 2],
-                sem=local_sems.at[bw3_sem_id, 3],
-            ).wait()
 
     def start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, bd2_id):
         next_bd1_id = bd1_id + 1
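
The "_x2" buffers and paired start_fetch_bw*/wait_fetch_bw* helpers above implement double buffering: the copy of the next weight block overlaps compute on the current one. A hypothetical plain-Python sketch of that control flow (start_fetch, wait_fetch and compute are stand-in callables, not functions from this package):

def pipelined_blocks(num_blocks, start_fetch, wait_fetch, compute):
    # Ping-pong between two slots so DMA of block i+1 overlaps compute on block i.
    if num_blocks <= 0:
        return
    start_fetch(block=0, slot=0)
    for i in range(num_blocks):
        slot = i % 2
        if i + 1 < num_blocks:
            # Kick off the next copy into the other slot before computing.
            start_fetch(block=i + 1, slot=(i + 1) % 2)
        wait_fetch(block=i, slot=slot)
        compute(block=i, slot=slot)
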
@@ -718,38 +520,18 @@ def _fused_ep_moe_kernel(
     def dynamic_ffn1(
         t_b32_vmem,
         w1_vmem,
-        w1_scale_vmem,
-        b1_vmem,
         w3_vmem,
-        w3_scale_vmem,
-        b3_vmem,
         acc1_vmem,
         acc3_vmem,
         dyn_sz,
         should_init,
     ):
         assert t_b32_vmem.shape == (bt * num_devices, bd1 // t_packing)
-        assert w1_vmem.shape == w3_vmem.shape == (t_packing,
+        assert w1_vmem.shape == w3_vmem.shape == (t_packing, bd1_per_packing,
                                                   bf)
         assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf)
         assert bd1 % (t_packing * 128) == 0, (bd1, t_packing)
         assert bd1c % (t_packing * 128) == 0, (bd1c, t_packing)
-        if w1_scale_vmem is not None:
-            assert w1_scale_vmem.shape == (
-                t_packing,
-                bd1_per_t_packing // subc_quant_wsz,
-                1,
-                bf,
-            )
-            assert bd1c_per_t_packing == subc_quant_wsz
-        if w3_scale_vmem is not None:
-            assert w3_scale_vmem.shape == (
-                t_packing,
-                bd1_per_t_packing // subc_quant_wsz,
-                1,
-                bf,
-            )
-            assert bd1c_per_t_packing == subc_quant_wsz
 
         num_loops = cdiv(dyn_sz, btc)
         repack_ty = jnp.dtype(f"int{t_bitwidth}")
@@ -758,7 +540,7 @@ def _fused_ep_moe_kernel(
             for bd1c_id in range(cdiv(bd1, bd1c)):
                 t_b32 = t_b32_vmem[
                     pl.ds(btc_id * btc, btc),
-                    pl.ds(bd1c_id *
+                    pl.ds(bd1c_id * bd1c_per_packing, bd1c_per_packing),
                 ]
                 for p_id in range(t_packing):
                     t = pltpu.bitcast(t_b32.astype(repack_ty), t_dtype)
@@ -766,64 +548,21 @@ def _fused_ep_moe_kernel(
                    for bfc_id in range(cdiv(bf, bfc)):
                        w_slices = (
                            p_id,
-                            pl.ds(bd1c_id *
-
+                            pl.ds(bd1c_id * bd1c_per_packing,
+                                  bd1c_per_packing),
                            pl.ds(bfc_id * bfc, bfc),
                        )
                        w1 = w1_vmem[*w_slices]
                        acc1 = jnp.dot(t,
                                       w1,
                                       preferred_element_type=jnp.float32)
-
-                        if w1_scale_vmem is not None:
-                            w1_scale_slices = (
-                                p_id,
-                                bd1c_id,
-                                pl.ds(0, 1),
-                                pl.ds(bfc_id * bfc, bfc),
-                            )
-                            # TODO(jevinjiang): can use mosaic to load with stride 0.
-                            w1_scale = jnp.broadcast_to(
-                                w1_scale_vmem[*w1_scale_slices], acc1.shape)
-                            acc1 *= w1_scale
-
                        w3 = w3_vmem[*w_slices]
-
                        acc3 = jnp.dot(t,
                                       w3,
                                       preferred_element_type=jnp.float32)
-
-                        if w3_scale_vmem is not None:
-                            w3_scale_slices = (
-                                p_id,
-                                bd1c_id,
-                                pl.ds(0, 1),
-                                pl.ds(bfc_id * bfc, bfc),
-                            )
-                            w3_scale = jnp.broadcast_to(
-                                w3_scale_vmem[*w3_scale_slices], acc3.shape)
-                            acc3 *= w3_scale
-
                        acc_slices = (pl.ds(btc_id * btc,
                                            btc), pl.ds(bfc_id * bfc, bfc))
                        if should_init and p_id == bd1c_id == 0:
-                            if b1_vmem is not None:
-                                b1_scale_slices = (
-                                    pl.ds(0, 1),
-                                    pl.ds(bfc_id * bfc, bfc),
-                                )
-                                b1 = jnp.broadcast_to(
-                                    b1_vmem[*b1_scale_slices], acc1.shape)
-                                acc1 += b1
-                            if b3_vmem is not None:
-                                b3_scale_slices = (
-                                    pl.ds(0, 1),
-                                    pl.ds(bfc_id * bfc, bfc),
-                                )
-                                b3 = jnp.broadcast_to(
-                                    b3_vmem[*b3_scale_slices], acc1.shape)
-                                acc3 += b3
-
                            acc1_vmem[*acc_slices] = acc1
                            acc3_vmem[*acc_slices] = acc3
                        else:
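
Taken together with dynamic_ffn2 in the hunks below, the blocked matmuls above implement a standard gated expert FFN. A dense sketch of the same math (reference only, not the blocked Pallas code):

import jax
import jax.numpy as jnp

def expert_ffn_reference(x, w1, w3, w2):
    # Gate and up projections, a SiLU-gated product, then the down projection,
    # all accumulated in float32 as in dynamic_ffn1/dynamic_ffn2.
    acc1 = jnp.dot(x, w1, preferred_element_type=jnp.float32)
    acc3 = jnp.dot(x, w3, preferred_element_type=jnp.float32)
    act = jax.nn.silu(acc1) * acc3
    return jnp.dot(act, w2, preferred_element_type=jnp.float32)
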
@@ -836,28 +575,22 @@ def _fused_ep_moe_kernel(
         acc1_vmem,
         acc3_vmem,
         w2_vmem,
-        w2_scale_vmem,
-        b2_vmem,
         res_b32_vmem,
         dyn_sz,
         should_init,
     ):
-        assert res_b32_vmem.shape == (bt * num_devices,
-        assert w2_vmem.shape == (t_packing, bf,
+        assert res_b32_vmem.shape == (bt * num_devices, bd2_per_packing)
+        assert w2_vmem.shape == (t_packing, bf, bd2_per_packing), (
+            w2_vmem.shape,
+            t_packing,
+            bf,
+            bd2_per_packing,
+        )
         assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf)
         assert bd2 % (t_packing * 128) == 0, (bd2, t_packing)
         assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing)
         assert t_dtype in (jnp.float32, jnp.bfloat16)
 
-        if w2_scale_vmem is not None:
-            assert w2_scale_vmem.shape == (
-                t_packing,
-                bf // subc_quant_wsz,
-                1,
-                bd2_per_t_packing,
-            )
-            assert bfc == subc_quant_wsz
-
         num_loops = cdiv(dyn_sz, btc)
         assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing)
 
@@ -865,47 +598,22 @@ def _fused_ep_moe_kernel(
             for bd2c_id in range(cdiv(bd2, bd2c)):
                 res_lst = []
                 for p_id in range(t_packing):
-                    res = jnp.zeros((btc,
-                                    dtype=jnp.float32)
-
-                    if b2_vmem is not None and should_init:
-                        b2_scale_slices = (
-                            p_id,
-                            pl.ds(0, 1),
-                            pl.ds(bd2c_id * bd2c_per_t_packing,
-                                  bd2c_per_t_packing),
-                        )
-                        b2 = jnp.broadcast_to(b2_vmem[*b2_scale_slices],
-                                              res.shape)
-                        res += b2
-
+                    res = jnp.zeros((btc, bd2c_per_packing), dtype=jnp.float32)
                     for bfc_id in range(cdiv(bf, bfc)):
                        acc_slices = (pl.ds(btc_id * btc,
                                            btc), pl.ds(bfc_id * bfc, bfc))
                        acc1 = acc1_vmem[*acc_slices]
                        acc3 = acc3_vmem[*acc_slices]
-                        act =
+                        act = jax.nn.silu(acc1) * acc3
                        w2 = w2_vmem[
                            p_id,
                            pl.ds(bfc_id * bfc, bfc),
                            pl.ds(bd2c_id *
-
+                                  bd2c_per_packing, bd2c_per_packing),
                        ]
-
-
-
-                        if w2_scale_vmem is not None:
-                            w2_scale_slices = (
-                                p_id,
-                                bfc_id,
-                                pl.ds(0, 1),
-                                pl.ds(bd2c_id * bd2c_per_t_packing,
-                                      bd2c_per_t_packing),
-                            )
-                            w2_scale = jnp.broadcast_to(
-                                w2_scale_vmem[*w2_scale_slices], acc.shape)
-                            acc *= w2_scale
-                        res += acc
+                        res += jnp.dot(act,
+                                       w2,
+                                       preferred_element_type=jnp.float32)
                    res = pltpu.bitcast(res, jnp.uint32)
                    if t_packing == 2:
                        res = res >> 16 << (16 * p_id)
@@ -918,7 +626,7 @@ def _fused_ep_moe_kernel(
                        res |= res_lst[i]
                    sliced_res_vmem = res_b32_vmem.at[
                        pl.ds(btc_id * btc, btc),
-                        pl.ds(bd2c_id *
+                        pl.ds(bd2c_id * bd2c_per_packing, bd2c_per_packing),
                    ]
                    if should_init:
                        sliced_res_vmem[...] = res
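
The tail of dynamic_ffn2 above packs results into 32-bit words: keeping the top 16 bits of a float32 truncates it to bfloat16, and the two packing lanes are OR-ed together. A small sketch of that bit trick outside the kernel (illustrative only):

import jax.numpy as jnp
from jax import lax

def pack_two_f32_as_bf16_pair(x0, x1):
    # Reinterpret float32 as uint32, keep the high 16 bits (the bf16 bits),
    # then place lane 0 in the low half and lane 1 in the high half,
    # mirroring `res >> 16 << (16 * p_id)` above.
    b0 = lax.bitcast_convert_type(x0, jnp.uint32) >> 16
    b1 = (lax.bitcast_convert_type(x1, jnp.uint32) >> 16) << 16
    return b0 | b1
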
@@ -947,33 +655,21 @@ def _fused_ep_moe_kernel(
             e_id = my_id * local_num_experts + local_e_id
             dyn_sz = expert_sizes_x2_smem[bt_sem_id, 0, e_id]
 
-
-
+            bd1_per_packing = bd1 // t_packing
+            bd2_per_packing = bd2 // t_packing
 
             for bf_id in range(num_bf):
                 for bd1_id in range(num_bd1):
                     start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, 0)
-                    w1_scale_vmem = (None if b_w1_scale_x2_vmem is None else
-                                     b_w1_scale_x2_vmem.at[bw_sem_id])
-                    w3_scale_vmem = (None if b_w3_scale_x2_vmem is None else
-                                     b_w3_scale_x2_vmem.at[bw_sem_id])
-                    b1_vmem = None if b_b1_x2_vmem is None else b_b1_x2_vmem.at[
-                        bf_id % 2]
-                    b3_vmem = None if b_b3_x2_vmem is None else b_b3_x2_vmem.at[
-                        bf_id % 2]
                     wait_fetch_bw1(local_e_id, bw_sem_id, bf_id, bd1_id)
                     wait_fetch_bw3(local_e_id, bw_sem_id, bf_id, bd1_id)
 
                     dynamic_ffn1(
                         t_b32_vmem=a2a_s_b32_vmem.at[
                             ...,
-                            pl.ds(bd1_id *
+                            pl.ds(bd1_id * bd1_per_packing, bd1_per_packing)],
                         w1_vmem=b_w1_x2_vmem.at[bw_sem_id],
-                        w1_scale_vmem=w1_scale_vmem,
-                        b1_vmem=b1_vmem,
                         w3_vmem=b_w3_x2_vmem.at[bw_sem_id],
-                        w3_scale_vmem=w3_scale_vmem,
-                        b3_vmem=b3_vmem,
                         acc1_vmem=b_acc1_vmem,
                         acc3_vmem=b_acc3_vmem,
                         dyn_sz=dyn_sz,
@@ -988,19 +684,13 @@ def _fused_ep_moe_kernel(
                     if bf_id == bd2_id == 0:
                         wait_a2a_gather_send(bt_id, e_sem_id, local_e_id - 2)
 
-                    w2_scale_vmem = (None if b_w2_scale_x2_vmem is None else
-                                     b_w2_scale_x2_vmem.at[bw_sem_id])
-                    b2_vmem = None if b_b2_x2_vmem is None else b_b2_x2_vmem.at[
-                        bd2_id % 2]
                     dynamic_ffn2(
                         acc1_vmem=b_acc1_vmem,
                         acc3_vmem=b_acc3_vmem,
                         w2_vmem=b_w2_x2_vmem.at[bw_sem_id],
-                        w2_scale_vmem=w2_scale_vmem,
-                        b2_vmem=b2_vmem,
                         res_b32_vmem=a2a_s_acc_b32_vmem.at[
                             ...,
-                            pl.ds(bd2_id *
+                            pl.ds(bd2_id * bd2_per_packing, bd2_per_packing)],
                         dyn_sz=dyn_sz,
                         should_init=(bf_id == 0),
                     )
@@ -1067,7 +757,7 @@ def _fused_ep_moe_kernel(
         b_gating = b_gating_x2_vmem[bt_sem_id]
         b_gating_score = jax.nn.softmax(b_gating, axis=-1)
         top_k_logits_lst, t2e_routing, expert_sizes, expert_starts = get_top_k(
-            b_gating_score, top_k
+            b_gating_score, top_k)
 
         all_reduce_metadata(bt_sem_id, t2e_routing, expert_starts,
                             expert_sizes)
@@ -1137,9 +827,6 @@ def _fused_ep_moe_kernel(
     static_argnames=[
         "mesh",
         "top_k",
-        "renormalize_topk_logits",
-        "act_fn",
-        "subc_quant_wsz",
         "bt",
         "bf",
         "bd1",
@@ -1158,18 +845,7 @@ def fused_ep_moe(
     w2: jax.Array,  # (num_experts, intermediate_size, hidden_size)
     gating_output: jax.Array,  # (num_tokens, num_experts)
     top_k: int,
-    renormalize_topk_logits: bool = False,
-    act_fn: str = "silu",
     *,
-    subc_quant_wsz: int | None = None,
-    w1_scale: (
-        jax.Array | None
-    ) = None,  # (num_experts, 2, cdiv(hidden_size, subc_quant_wsz), intermediate_size)
-    w2_scale: (
-        jax.Array | None
-    ) = None,  # (num_experts, cdiv(intermediate_size, subc_quant_wsz), hidden_size)
-    b1: jax.Array | None = None,  # (num_experts, 2, intermediate_size)
-    b2: jax.Array | None = None,  # (num_experts, hidden_size)
     # Kernel tuning parameters.
     bt: int,
     bf: int,
@@ -1179,19 +855,18 @@ def fused_ep_moe(
     bfc: int,
     bd1c: int,
     bd2c: int,
-    ep_axis_name: str =
+    ep_axis_name: str = 'model',
 ):
-    # TODO(jevinjiang): move all these assertions to validation function.
     # Assert all other axes have length of 1
-    assert len(mesh.shape) == 2, "Expect 2D mesh"
-    assert
-
+    assert len(mesh.shape) == 2, "Expect 2D mesh in tpu-inference"
+    assert 'data' in mesh.shape and mesh.shape['data'] == 1, \
+        "Expect data axis size of 1 in tpu-inference"
 
     ep_size = mesh.shape[ep_axis_name]
     num_devices = ep_size
 
     num_tokens, actual_hidden_size = tokens.shape
-    num_experts,
+    num_experts, intermediate_size, _ = w2.shape
 
     assert num_tokens % ep_size == 0
     assert num_experts % ep_size == 0
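
A hypothetical call sketch for the simplified fused_ep_moe API above, assuming the mesh is the leading argument (it appears in static_argnames). The shapes and block sizes here are made-up illustrations, not tuning recommendations, and running it requires TPU devices:

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh

devices = jax.devices()
# 2D mesh with a size-1 "data" axis and experts sharded over "model",
# matching the assertions above.
mesh = Mesh(mesh_utils.create_device_mesh((1, len(devices))),
            axis_names=("data", "model"))
e, h, f, t = 8 * len(devices), 1024, 4096, 8 * len(devices)
tokens = jnp.zeros((t, h), jnp.bfloat16)
w1 = jnp.zeros((e, 2, h, f), jnp.bfloat16)
w2 = jnp.zeros((e, f, h), jnp.bfloat16)
gating = jnp.zeros((t, e), jnp.float32)
out = fused_ep_moe(mesh, tokens, w1, w2, gating, top_k=2,
                   bt=8, bf=512, bd1=512, bd2=512,
                   btc=8, bfc=512, bd1c=512, bd2c=512,
                   ep_axis_name="model")
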
@@ -1199,18 +874,26 @@ def fused_ep_moe(
     local_num_tokens = num_tokens // ep_size
     # local_num_experts = num_experts // ep_size
     padded_num_experts = align_to(num_experts, 128)
+
     t_dtype = tokens.dtype
     t_packing = get_dtype_packing(t_dtype)
+    hidden_size = align_to(actual_hidden_size, 128 * t_packing)
+    if hidden_size != actual_hidden_size:
+        tokens = jnp.pad(
+            tokens,
+            ((0, 0), (0, hidden_size - actual_hidden_size)),
+            constant_values=0,
+        )
+    tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing)
+    bt = min(bt, local_num_tokens)
+    bf = min(bf, intermediate_size)
+    bd1 = min(bd1, hidden_size)
+    bd2 = min(bd2, hidden_size)
 
-
-
-
-
-    # We force compute size of contracting dim to subc_quant_wsz. So we can
-    # apply same scale after matmul and accumulation.
-    bd1c = subc_quant_wsz * t_packing
-    bfc = subc_quant_wsz
-
+    btc = min(btc, bt * num_devices)
+    bfc = min(bfc, bf)
+    bd1c = min(bd1c, bd1)
+    bd2c = min(bd2c, bd2)
     assert bfc % 128 == 0
     assert bd1c % (t_packing * 128) == 0
     assert bd2c % (t_packing * 128) == 0
@@ -1218,41 +901,6 @@ def fused_ep_moe(
     assert bd1 % bd1c == 0
     assert bd2 % bd2c == 0
 
-    btc = min(btc, bt * num_devices)
-    hidden_size = align_to(actual_hidden_size, 128 * t_packing)
-    # TODO(jevinjiang): instead of padding outside the kernel, we can try dynammic
-    # masking inside the kernel.
-    hidden_size = align_to(hidden_size, bd1)
-    hidden_size = align_to(hidden_size, bd2)
-    intermediate_size = align_to(actual_intermediate_size, bf)
-
-    # TODO(jevinjiang): we should dump scale as the kernel expected shape in the
-    # checkpoint offline or reshape right after weight loading.
-    if w1_scale is not None:
-        assert w1_scale.shape[0] == w1.shape[0]
-        assert w1_scale.shape[1] == w1.shape[1] == 2
-        assert w1_scale.shape[2] == cdiv(w1.shape[2], subc_quant_wsz)
-        assert w1_scale.shape[3] == w1.shape[3]
-        w1_scale = jnp.expand_dims(w1_scale.astype(jnp.float32), axis=-2)
-
-    if w2_scale is not None:
-        assert w2_scale.shape[0] == w2.shape[0]
-        assert w2_scale.shape[1] == cdiv(w2.shape[1], subc_quant_wsz)
-        assert w2_scale.shape[2] == w2.shape[2]
-        w2_scale = jnp.expand_dims(w2_scale.astype(jnp.float32), axis=-2)
-
-    if b1 is not None:
-        assert b1.shape[0] == w1.shape[0]
-        assert b1.shape[1] == w1.shape[1] == 2
-        assert b1.shape[2] == w1.shape[3]
-        b1 = jnp.expand_dims(b1.astype(jnp.float32), axis=-2)
-
-    if b2 is not None:
-        assert b2.shape[0] == w2.shape[0]
-        assert b2.shape[1] == w2.shape[2]
-        b2 = jnp.expand_dims(b2.astype(jnp.float32), axis=-2)
-
-    # Prepare inputs for the kernel.
     if padded_num_experts != gating_output.shape[-1]:
         gating_output = jnp.pad(
             gating_output,
@@ -1260,92 +908,13 @@ def fused_ep_moe(
             constant_values=-jnp.inf,
         )
 
-
-            or intermediate_size != actual_intermediate_size):
-        tokens = jnp.pad(
-            tokens,
-            ((0, 0), (0, hidden_size - actual_hidden_size)),
-            constant_values=0,
-        )
-        w1 = jnp.pad(
-            w1,
-            (
-                (0, 0),
-                (0, 0),
-                (0, hidden_size - actual_hidden_size),
-                (0, intermediate_size - actual_intermediate_size),
-            ),
-            constant_values=0,
-        )
-        w2 = jnp.pad(
-            w2,
-            (
-                (0, 0),
-                (0, intermediate_size - actual_intermediate_size),
-                (0, hidden_size - actual_hidden_size),
-            ),
-            constant_values=0,
-        )
-        if w1_scale is not None:
-            w1_scale = jnp.pad(
-                w1_scale,
-                (
-                    (0, 0),
-                    (0, 0),
-                    (0,
-                     cdiv(hidden_size, subc_quant_wsz) - w1_scale.shape[-3]),
-                    (0, 0),
-                    (0, intermediate_size - w1_scale.shape[-1]),
-                ),
-                constant_values=0,
-            )
-        if w2_scale is not None:
-            w2_scale = jnp.pad(
-                w2_scale,
-                (
-                    (0, 0),
-                    (0, cdiv(intermediate_size, subc_quant_wsz) -
-                     w2_scale.shape[-3]),
-                    (0, 0),
-                    (0, hidden_size - w2_scale.shape[-1]),
-                ),
-                constant_values=0,
-            )
-        if b1 is not None:
-            b1 = jnp.pad(
-                b1,
-                (
-                    (0, 0),
-                    (0, 0),
-                    (0, 0),
-                    (0, intermediate_size - b1.shape[-1]),
-                ),
-                constant_values=0,
-            )
-        if b2 is not None:
-            b2 = jnp.pad(
-                b2,
-                (
-                    (0, 0),
-                    (0, 0),
-                    (0, hidden_size - b2.shape[-1]),
-                ),
-                constant_values=0,
-            )
-
-    tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing)
-
-    hbm_block_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM)
-    scope_name = f"fused_moe_k-{top_k}_renorm-{renormalize_topk_logits}_bt-{bt}-{btc}_bf-{bf}-{bfc}_bd1-{bd1}-{bd1c}_bd2-{bd2}-{bd2c}"
+    scope_name = f"fused_moe_k-{top_k}_bt-{bt}-{btc}_bf-{bf}-{bfc}_bd1-{bd1}-{bd1c}_bd2-{bd2}-{bd2c}"
     fused_moe = jax.named_scope(scope_name)(
         pl.pallas_call(
             functools.partial(
                 _fused_ep_moe_kernel,
                 top_k=top_k,
-                renormalize_topk_logits=renormalize_topk_logits,
                 ep_axis_name=ep_axis_name,
-                act_fn=act_fn,
-                subc_quant_wsz=subc_quant_wsz,
                 bt=bt,
                 bf=bf,
                 bd1=bd1,
@@ -1360,17 +929,11 @@ def fused_ep_moe(
         grid_spec=pltpu.PrefetchScalarGridSpec(
             num_scalar_prefetch=0,
             in_specs=[
-
-
-
-
-
-                None
-                if w2_scale is None else hbm_block_spec,  # w2_scale_hbm
-                None if b1 is None else hbm_block_spec,  # b1_hbm
-                None if b2 is None else hbm_block_spec,  # b2_hbm
-                hbm_block_spec,  # gating_output_hbm
-                hbm_block_spec,  # a2a_g_hbm
+                pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
+                pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
+                pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
+                pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
+                pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
             ],
             out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
             scratch_shapes=([
@@ -1421,67 +984,6 @@ def fused_ep_moe(
                 pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
                 # b_w2_x2_vmem
                 pltpu.VMEM((2, t_packing, bf, bd2 // t_packing), w2.dtype),
-                # b_w1_scale_x2_vmem
-                (None if w1_scale is None else pltpu.VMEM(
-                    (
-                        2,
-                        t_packing,
-                        bd1 // t_packing // subc_quant_wsz,
-                        1,
-                        bf,
-                    ),
-                    jnp.float32,
-                )),
-                # b_w3_scale_x2_vmem
-                (None if w1_scale is None else pltpu.VMEM(
-                    (
-                        2,
-                        t_packing,
-                        bd1 // t_packing // subc_quant_wsz,
-                        1,
-                        bf,
-                    ),
-                    jnp.float32,
-                )),
-                # b_w2_scale_x2_vmem
-                (None if w2_scale is None else pltpu.VMEM(
-                    (
-                        2,
-                        t_packing,
-                        bf // subc_quant_wsz,
-                        1,
-                        bd2 // t_packing,
-                    ),
-                    jnp.float32,
-                )),
-                # b_b1_x2_vmem
-                (None if b1 is None else pltpu.VMEM(
-                    (
-                        2,
-                        1,
-                        bf,
-                    ),
-                    jnp.float32,
-                )),
-                # b_b3_x2_vmem
-                (None if b1 is None else pltpu.VMEM(
-                    (
-                        2,
-                        1,
-                        bf,
-                    ),
-                    jnp.float32,
-                )),
-                # b_b2_x2_vmem
-                (None if b2 is None else pltpu.VMEM(
-                    (
-                        2,
-                        t_packing,
-                        1,
-                        bd2 // t_packing,
-                    ),
-                    jnp.float32,
-                )),
                 # b_acc_vmem
                 pltpu.VMEM((bt * num_devices, 1, bf * 2), jnp.float32),
                 # local_sems
@@ -1504,50 +1006,21 @@ def fused_ep_moe(
     ))
 
     @jax.jit
-    @
+    @functools.partial(
+        shard_map.shard_map,
         mesh=mesh,
-        in_specs=(
-
-            P(ep_axis_name),  # w1_hbm
-            P(ep_axis_name),  # w2_hbm
-            None if w1_scale is None else P(ep_axis_name),  # w1_scale_hbm
-            None if w2_scale is None else P(ep_axis_name),  # w2_scale_hbm
-            None if b1 is None else P(ep_axis_name),  # b1_hbm
-            None if b2 is None else P(ep_axis_name),  # b2_hbm
-            P(ep_axis_name),  # gating_output_hbm
-            P(),  # a2a_g_hbm
-        ),
+        in_specs=(P(ep_axis_name), P(ep_axis_name), P(ep_axis_name),
+                  P(ep_axis_name), P()),
         out_specs=P(ep_axis_name),
-
+        check_rep=False,
     )
-    def kernel(
-        tokens,
-        w1,
-        w2,
-        w1_scale,
-        w2_scale,
-        b1,
-        b2,
-        gating_output,
-        a2a_g_hbm_scratch,
-    ):
+    def kernel(tokens, w1, w2, gating_output, a2a_g_hbm_scratch):
         return fused_moe(
-            pltpu.with_memory_space_constraint(tokens,
-
-            pltpu.with_memory_space_constraint(
-            pltpu.with_memory_space_constraint(
-            (
-                w1_scale, pltpu.HBM)),  # w1_scale_hbm
-            (None if w2_scale is None else pltpu.with_memory_space_constraint(
-                w2_scale, pltpu.HBM)),  # w2_scale_hbm
-            (None if b1 is None else pltpu.with_memory_space_constraint(
-                b1, pltpu.HBM)),  # b1_hbm
-            (None if b2 is None else pltpu.with_memory_space_constraint(
-                b2, pltpu.HBM)),  # b2_hbm
-            pltpu.with_memory_space_constraint(gating_output,
-                                               pltpu.HBM),  # gating_output_hbm
-            pltpu.with_memory_space_constraint(a2a_g_hbm_scratch,
-                                               pltpu.HBM),  # a2a_g_hbm
+            pltpu.with_memory_space_constraint(tokens, pltpu.HBM),
+            pltpu.with_memory_space_constraint(w1, pltpu.HBM),
+            pltpu.with_memory_space_constraint(w2, pltpu.HBM),
+            pltpu.with_memory_space_constraint(gating_output, pltpu.HBM),
+            pltpu.with_memory_space_constraint(a2a_g_hbm_scratch, pltpu.HBM),
        )
 
    a2a_g_hbm_scratch = pl.empty(
@@ -1556,10 +1029,6 @@ def fused_ep_moe(
        tokens,
        w1,
        w2,
-        w1_scale,
-        w2_scale,
-        b1,
-        b2,
        gating_output,
        a2a_g_hbm_scratch,
    )
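
The wrapper in the last hunks applies shard_map so each device runs the Pallas kernel on its own expert shard. A self-contained toy sketch of that wrapping pattern (unrelated computation, assumed axis names matching the code above):

import functools
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils, shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(mesh_utils.create_device_mesh((1, jax.device_count())),
            axis_names=("data", "model"))

@functools.partial(shard_map.shard_map, mesh=mesh,
                   in_specs=(P("model"),), out_specs=P("model"),
                   check_rep=False)
def per_device_scale(w_local):
    # Each device sees only its shard of the leading (expert) dimension.
    return w_local * 2.0

w = jnp.ones((8 * jax.device_count(), 4))
print(per_device_scale(w).shape)  # global shape, computed shard-locally
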