tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202512030818__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/lora/test_layers.py +0 -6
- tests/lora/utils.py +0 -8
- tests/test_envs.py +32 -11
- tests/test_utils.py +1 -2
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +3 -4
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +61 -8
- tpu_inference/executors/ray_distributed_executor.py +31 -11
- tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +213 -126
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +74 -25
- tpu_inference/layers/vllm/quantization/common.py +6 -1
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -62
- tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +45 -11
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +3 -6
- tpu_inference/models/jax/utils/weight_utils.py +198 -143
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -7
- tpu_inference/platforms/tpu_platform.py +28 -22
- tpu_inference/runner/compilation_manager.py +144 -59
- tpu_inference/runner/kv_cache_manager.py +17 -18
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +271 -147
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -21
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +36 -13
- tpu_inference/worker/tpu_worker.py +162 -25
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/METADATA +3 -2
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/RECORD +48 -53
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202512030818.dist-info}/top_level.txt +0 -0
@@ -7,7 +7,6 @@ import jax.numpy as jnp
 from jax import lax
 from jax._src import dtypes
 from jax.experimental import pallas as pl
-from jax.experimental import shard_map
 from jax.experimental.pallas import tpu as pltpu
 
 P = jax.sharding.PartitionSpec
@@ -35,13 +34,49 @@ def broadcast_minor(src, shape):
                     axis=-1)[..., :shape[-1]]
 
 
+def swigluoai(gate: jax.Array,
+              up: jax.Array,
+              *,
+              alpha: float = 1.702,
+              limit: float = 7.0) -> jax.Array:
+    """Activation used in some models such as GPT-OSS."""
+    gate = jnp.clip(gate, a_max=limit)
+    up = jnp.clip(up, a_min=-limit, a_max=limit)
+    glu = gate * jax.nn.sigmoid(alpha * gate)
+    return (up + 1.0) * glu
+
+
+def activation_fn(acc1, acc3, act_fn):
+    if act_fn == "silu":
+        return jax.nn.silu(acc1) * acc3
+    elif act_fn == "gelu":
+        return jax.nn.gelu(acc1) * acc3
+    elif act_fn == "swigluoai":
+        return swigluoai(acc1, acc3)
+    else:
+        raise RuntimeError(f"Unsupported activation function: {act_fn}")
+
+
 def ref_moe(
-
-
-
-
-
-
+    tokens: jax.Array,  # (num_tokens, hidden_size)
+    w1: jax.Array,  # (num_experts, 2, hidden_size, intermediate_size)
+    w2: jax.Array,  # (num_experts, intermediate_size, hidden_size)
+    gating_output: jax.Array,  # (num_tokens, num_experts)
+    top_k: int,
+    *,
+    renormalize_topk_logits: bool = False,
+    activation="silu",
+    subc_quant_wsz: int | None = None,
+    w1_scale:
+    (
+        jax.Array | None
+    ) = None,  # (num_experts, 2, cdiv(hidden_size, subc_quant_wsz), intermediate_size)
+    w2_scale:
+    (
+        jax.Array | None
+    ) = None,  # (num_experts, cdiv(intermediate_size, subc_quant_wsz), hidden_size)
+    b1: jax.Array | None = None,  # (num_experts, 2, intermediate_size)
+    b2: jax.Array | None = None,  # (num_experts, hidden_size)
 ):
     n_tokens = tokens.shape[0]  # num_tokens
 
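Note: the `swigluoai` helper added above is the clamped, sigmoid-gated activation used by GPT-OSS-style models. A minimal standalone sketch of the same math (illustrative only; it assumes nothing beyond a stock `jax` install):

```python
import jax
import jax.numpy as jnp


def swigluoai_ref(gate, up, alpha=1.702, limit=7.0):
    # Clamp the gate from above and the up projection into [-limit, limit].
    gate = jnp.minimum(gate, limit)
    up = jnp.clip(up, -limit, limit)
    # Sigmoid-scaled gate, then shift-and-multiply with the up projection.
    return (up + 1.0) * (gate * jax.nn.sigmoid(alpha * gate))


gate = jnp.array([[-1.0, 0.5, 9.0]])
up = jnp.array([[0.2, -0.3, 8.0]])
print(swigluoai_ref(gate, up))  # values beyond limit=7.0 are clamped before gating
```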
@@ -53,7 +88,12 @@ def ref_moe(
     top_k_logits, top_k_indices = lax.top_k(
         gating_logits, top_k)  # [num_tokens, top_k], [num_tokens, top_k]
 
+    if renormalize_topk_logits:
+        top_k_logits = top_k_logits / jnp.sum(
+            top_k_logits, axis=-1, keepdims=True)
+
     t_outputs = []
+    hidden_size, intermediate_size = w1.shape[-2:]
 
     # Process each token individually
     for i in range(n_tokens):
@@ -65,10 +105,24 @@ def ref_moe(
         # Process each selected expert for the current token
         for expert_id in assigned_expert_ids:
             # Get expert weights
+            expert_w1 = w1[expert_id, 0].astype(jnp.float32)
+            expert_w3 = w1[expert_id, 1].astype(jnp.float32)
+            if w1_scale is not None:
+                expert_w1 *= jnp.repeat(w1_scale[expert_id, 0],
+                                        subc_quant_wsz,
+                                        axis=0)[:hidden_size]
+                expert_w3 *= jnp.repeat(w1_scale[expert_id, 1],
+                                        subc_quant_wsz,
+                                        axis=0)[:hidden_size]
             expert_weight_1 = jnp.concat(
-                [
+                [expert_w1, expert_w3],
                 axis=-1)  # [d_model, 2 * intermediate_size]
-            expert_weight_2 = w2[expert_id]
+            expert_weight_2 = w2[expert_id].astype(
+                jnp.float32)  # [intermediate_size, d_model]
+            if w2_scale is not None:
+                expert_weight_2 *= jnp.repeat(w2_scale[expert_id],
+                                              subc_quant_wsz,
+                                              axis=0)[:intermediate_size]
 
             # First linear layer with SwiGLU activation
             gmm_1_out = curr_token @ expert_weight_1  # [1, 2 * intermediate_size]
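Note: the `jnp.repeat(...)[:hidden_size]` pattern above expands one scale value per `subc_quant_wsz`-wide window of the contracting dimension back to full resolution before multiplying. A small self-contained sketch of that dequantization step (shapes are illustrative only; the kernel itself requires the window to be a multiple of 256):

```python
import jax.numpy as jnp

subc_quant_wsz = 4           # illustrative window size
hidden_size, inter = 10, 3   # hidden_size deliberately not a multiple of the window

w_q = jnp.ones((hidden_size, inter), dtype=jnp.int8)           # quantized weight block
scales = jnp.arange(1.0, 4.0)[:, None] * jnp.ones((3, inter))  # one scale row per window: cdiv(10, 4) = 3

# Repeat each per-window scale subc_quant_wsz times along the contracting dim,
# then trim back to hidden_size, matching the ref_moe change above.
dequant = w_q.astype(jnp.float32) * jnp.repeat(scales, subc_quant_wsz, axis=0)[:hidden_size]
print(dequant.shape)  # (10, 3)
```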
@@ -77,21 +131,17 @@ def ref_moe(
             gmm1_w1_proj, gmm1_w3_proj = jnp.split(
                 gmm_1_out, 2,
                 axis=-1)  # [1, intermediate_size], [1, intermediate_size]
+            if b1 is not None:
+                gmm1_w1_proj += b1[expert_id:expert_id + 1, 0]
+                gmm1_w3_proj += b1[expert_id:expert_id + 1, 1]
 
             # Apply gated activation: activation(gate) * up
-
-                act = jax.nn.silu(
-                    gmm1_w1_proj) * gmm1_w3_proj  # [1, intermediate_size]
-            elif activation == "gelu":
-                act = jax.nn.gelu(
-                    gmm1_w1_proj) * gmm1_w3_proj  # [1, intermediate_size]
-            else:
-                raise ValueError(
-                    f"Unsupported activation: {activation}. Use 'silu' or 'gelu'."
-                )
+            act = activation_fn(gmm1_w1_proj, gmm1_w3_proj, activation)
 
             # Second linear layer (down projection)
             gmm_2_out = act @ expert_weight_2  # [1, d_model]
+            if b2 is not None:
+                gmm_2_out += b2[expert_id:expert_id + 1]
             tok_expert_act.append(gmm_2_out)
 
         # Combine outputs from all selected experts
@@ -105,7 +155,7 @@ def ref_moe(
             axis=0,
             keepdims=True)  # [1, d_model]
 
-        t_outputs.append(weighted_output)
+        t_outputs.append(weighted_output.astype(tokens.dtype))
 
     return jnp.concatenate(t_outputs, axis=0)  # [num_tokens, d_model]
 
@@ -115,6 +165,13 @@ def _fused_ep_moe_kernel(
     tokens_hbm,  # (local_num_tokens, t_packing, hidden_size // t_packing)
     w1_hbm,  # (local_num_experts, 2, hidden_size, intermediate_size)
     w2_hbm,  # (local_num_experts, intermediate_size, hidden_size)
+    # TODO(jevinjiang): We choose F32 scale for easier slicing. The extra
+    # latency should be hidden in the pipeline overlaping. But is there a better
+    # way to do this?
+    w1_scale_hbm,  # None | F32(local_num_experts, 2, cdiv(hidden_size, subc_quant_wsz), 1, intermediate_size)
+    w2_scale_hbm,  # None | F32(local_num_experts, cdiv(intermediate_size, subc_quant_wsz), 1, hidden_size)
+    b1_hbm,  # None | F32(local_num_experts, 2, 1, intermediate_size)
+    b2_hbm,  # None | F32(local_num_experts, 1, hidden_size)
     gating_hbm,  # (local_num_tokens, padded_num_experts)
     a2a_g_hbm,  # (num_experts, bt, t_packing, hidden_size // t_packing)
     # Output
@@ -136,6 +193,12 @@ def _fused_ep_moe_kernel(
     b_w1_x2_vmem,  # <bw_sem_id> (2, t_packing, bd1 // t_packing, bf)
     b_w3_x2_vmem,  # <bw_sem_id> (2, t_packing, bd1 // t_packing, bf)
     b_w2_x2_vmem,  # <bw_sem_id> (2, t_packing, bf, bd2 // t_packing)
+    b_w1_scale_x2_vmem,  # None | <bw_sem_id> (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf)
+    b_w3_scale_x2_vmem,  # None | <bw_sem_id> (2, t_packing, bd1 // t_packing // subc_quant_wsz, 1, bf)
+    b_w2_scale_x2_vmem,  # None | <bw_sem_id> (2, t_packing, bf // subc_quant_wsz, 1, bd2 // t_packing)
+    b_b1_x2_vmem,  # None | <bw_sem_id> (2, 1, bf)
+    b_b3_x2_vmem,  # None | <bw_sem_id> (2, 1, bf)
+    b_b2_x2_vmem,  # None | <bw_sem_id> (2, t_packing, 1, bd2 // t_packing)
     b_acc_vmem,  # F32(bt * num_devices, 1, bf * 2)
     ### Semaphores:
     local_sems,  # (2, 5): 2 x [b_gating_sem, b_w1_sem, b_w2_sem, b_w3_sem, b_output_sem]
@@ -145,7 +208,10 @@ def _fused_ep_moe_kernel(
     a2a_acc_sem,
     *,
     top_k: int,
+    renormalize_topk_logits: bool,
     ep_axis_name: str,
+    act_fn: str,
+    subc_quant_wsz: int | None = None,
     # Kernel tuning params.
     bt: int,  # Block size of local_num_tokens.
     bf: int,  # Block size of intermediate_size.
@@ -160,34 +226,53 @@ def _fused_ep_moe_kernel(
     num_devices = lax.axis_size(ep_axis_name)
     local_num_tokens = tokens_hbm.shape[0]
     local_num_experts, intermediate_size, hidden_size = w2_hbm.shape
-    # num_experts = local_num_experts * num_devices
-    # padded_num_experts = expert_starts_x2_smem.shape[-1]
     right_id = (my_id + 1) % num_devices
 
     t_dtype = tokens_hbm.dtype
     t_packing = get_dtype_packing(t_dtype)
     t_bitwidth = 32 // t_packing
     assert a2a_g_hbm.dtype == t_dtype
-    assert w1_hbm.dtype ==
-    assert w2_hbm.dtype == t_dtype
+    assert w1_hbm.dtype == w2_hbm.dtype
 
-
-    assert
-
-
-
-
+    assert bd1 % bd1c == 0
+    assert bd2 % bd2c == 0
+    assert bf % bfc == 0
+    assert hidden_size % t_packing == 0
+    assert bd1 % t_packing == 0
+    assert bd2 % t_packing == 0
+    assert bd1c % t_packing == 0
+    assert bd2c % t_packing == 0
+
+    h_per_t_packing = hidden_size // t_packing
+    assert tokens_hbm.shape[-1] == h_per_t_packing
+    bd1_per_t_packing = bd1 // t_packing
+    bd2_per_t_packing = bd2 // t_packing
+    bd1c_per_t_packing = bd1c // t_packing
+    bd2c_per_t_packing = bd2c // t_packing
+
+    if subc_quant_wsz is not None:
+        assert subc_quant_wsz % 256 == 0
+        assert bd1c_per_t_packing == subc_quant_wsz
+        assert bfc == subc_quant_wsz
+        assert bd1 % subc_quant_wsz == 0
+        assert bf % subc_quant_wsz == 0
+        assert bd1_per_t_packing % subc_quant_wsz == 0
+        assert h_per_t_packing % subc_quant_wsz == 0
 
     num_bt = cdiv(local_num_tokens, bt)
     num_bf = cdiv(intermediate_size, bf)
     num_bd1 = cdiv(hidden_size, bd1)
     num_bd2 = cdiv(hidden_size, bd2)
 
+    def get_mesh_device_id(ep_rank):
+        dp_rank = jax.lax.axis_index("data")
+        return (dp_rank, ep_rank)
+
     def sync_barrier():
         barrier_sem = pltpu.get_barrier_semaphore()
         pltpu.semaphore_signal(
             barrier_sem,
-            device_id=(
+            device_id=get_mesh_device_id(right_id),
             device_id_type=pltpu.DeviceIdType.MESH,
         )
         pltpu.semaphore_wait(barrier_sem, 1)
@@ -212,7 +297,7 @@ def _fused_ep_moe_kernel(
             sem=b_gating_sem,
         ).wait()
 
-    def get_top_k(input, top_k):
+    def get_top_k(input, top_k, renormalize_topk_logits):
         assert len(input.shape) == 2, input.shape
         input = input.astype(jnp.float32)
         top_k_logits_lst = []
@@ -220,11 +305,15 @@ def _fused_ep_moe_kernel(
         t2e = jnp.zeros(input.shape, dtype=jnp.int32)
         t2e_routing = jnp.zeros(input.shape, dtype=jnp.int32)
         iota = jax.lax.broadcasted_iota(jnp.int32, input.shape, 1)
+        top_k_logits_sum = jnp.zeros((input.shape[0], 128), jnp.float32)
+
        for k_id in range(top_k):
-            # TODO(jevinjiang): return both top_k values and indices in
+            # TODO(jevinjiang): return both top_k values and indices in Mosaic
             top_k_logits = jnp.broadcast_to(
                 jnp.max(input, axis=1, keepdims=True),
                 (input.shape[0], 128)).astype(input.dtype)
+            if renormalize_topk_logits:
+                top_k_logits_sum += top_k_logits
             top_k_logits_lst.append(top_k_logits)
             # TODO(jevinjiang): support bf16 argmax in Mosaic
             top_k_indices = jnp.broadcast_to(
@@ -236,6 +325,11 @@ def _fused_ep_moe_kernel(
             if k_id != top_k - 1:
                 input = jnp.where(mask, -jnp.inf, input)
 
+        if renormalize_topk_logits:
+            for k_id in range(top_k):
+                top_k_logits_lst[
+                    k_id] = top_k_logits_lst[k_id] / top_k_logits_sum
+
         expert_sizes = jnp.sum(t2e, axis=0, keepdims=True)
         expert_starts = jnp.zeros_like(expert_sizes)
         return top_k_logits_lst, t2e_routing, expert_sizes, expert_starts
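Note: with `renormalize_topk_logits=True` the selected routing weights are rescaled to sum to 1 per token, matching the reference change earlier in this diff. A minimal sketch of that renormalization outside the kernel (illustrative only):

```python
import jax
import jax.numpy as jnp
from jax import lax

gating = jnp.array([[0.1, 0.5, 0.2, 0.2]])
probs = jax.nn.softmax(gating, axis=-1)
top_k_logits, top_k_indices = lax.top_k(probs, 2)

# Renormalize so the per-token routing weights of the selected experts sum to 1.
renorm = top_k_logits / jnp.sum(top_k_logits, axis=-1, keepdims=True)
print(renorm.sum(axis=-1))  # [1.]
```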
@@ -277,7 +371,7 @@ def _fused_ep_moe_kernel(
             dst_ref=d2e_count_vmem.at[row_id],
             send_sem=send_sem,
             recv_sem=recv_sem,
-            device_id=(
+            device_id=get_mesh_device_id(right_id),
             device_id_type=pltpu.DeviceIdType.MESH,
         ).wait()
         row_id = (row_id + num_devices - 1) % num_devices
@@ -359,10 +453,8 @@ def _fused_ep_moe_kernel(
                     pl.ds(start, remote_sz)],
                 send_sem=send_sems.at[e_sem_id],
                 recv_sem=recv_sems.at[e_sem_id],
-                device_id=(
-
-                    recv_id,
-                ),
+                device_id=get_mesh_device_id(recv_id),
+                device_id_type=pltpu.DeviceIdType.MESH,
             ).start()
             a2a_s_sends_x2_smem[e_sem_id] = send_sz
@@ -406,7 +498,8 @@ def _fused_ep_moe_kernel(
                 dst_ref=a2a_g_hbm.at[my_e_id, pl.ds(0, remote_sz)],
                 send_sem=send_sems.at[e_sem_id],
                 recv_sem=a2a_gather_sem,
-                device_id=(
+                device_id=get_mesh_device_id(recv_id),
+                device_id_type=pltpu.DeviceIdType.MESH,
             ).start()
             start += sz
 
@@ -435,68 +528,173 @@ def _fused_ep_moe_kernel(
 
     def start_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id):
         for p in range(t_packing):
-            offset = p *
+            offset = p * h_per_t_packing + bd1_id * bd1_per_t_packing
             pltpu.make_async_copy(
                 src_ref=w1_hbm.at[
                     local_e_id,
                     0,
-                    pl.ds(offset,
+                    pl.ds(offset, bd1_per_t_packing),
                     pl.ds(bf_id * bf, bf),
                 ],
                 dst_ref=b_w1_x2_vmem.at[bw1_sem_id, p],
                 sem=local_sems.at[bw1_sem_id, 1],
             ).start()
+            if w1_scale_hbm is not None:
+                assert subc_quant_wsz is not None
+                pltpu.make_async_copy(
+                    src_ref=w1_scale_hbm.at[
+                        local_e_id,
+                        0,
+                        pl.ds(
+                            offset // subc_quant_wsz,
+                            bd1_per_t_packing // subc_quant_wsz,
+                        ),
+                        pl.ds(0, 1),
+                        pl.ds(bf_id * bf, bf),
+                    ],
+                    dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id, p],
+                    sem=local_sems.at[bw1_sem_id, 1],
+                ).start()
+            if b1_hbm is not None and bd1_id == 0:
+                pltpu.make_async_copy(
+                    src_ref=b1_hbm.at[local_e_id, 0,
+                                      pl.ds(0, 1),
+                                      pl.ds(bf_id * bf, bf)],
+                    dst_ref=b_b1_x2_vmem.at[bf_id % 2],
+                    sem=local_sems.at[bw1_sem_id, 1],
+                ).start()
 
     def start_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id):
         for p in range(t_packing):
-            offset = p *
+            offset = p * h_per_t_packing + bd2_id * bd2_per_t_packing
             pltpu.make_async_copy(
                 src_ref=w2_hbm.at[
                     local_e_id,
                     pl.ds(bf_id * bf, bf),
-                    pl.ds(offset,
+                    pl.ds(offset, bd2_per_t_packing),
                 ],
                 dst_ref=b_w2_x2_vmem.at[bw2_sem_id, p],
                 sem=local_sems.at[bw2_sem_id, 2],
             ).start()
+            if w2_scale_hbm is not None:
+                assert subc_quant_wsz is not None
+                pltpu.make_async_copy(
+                    src_ref=w2_scale_hbm.at[
+                        local_e_id,
+                        pl.ds(bf_id * bf // subc_quant_wsz, bf //
+                              subc_quant_wsz),
+                        pl.ds(0, 1),
+                        pl.ds(offset, bd2_per_t_packing),
+                    ],
+                    dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id, p],
+                    sem=local_sems.at[bw2_sem_id, 2],
+                ).start()
+            if b2_hbm is not None and bf_id == 0:
+                pltpu.make_async_copy(
+                    src_ref=b2_hbm.at[local_e_id,
+                                      pl.ds(0, 1),
+                                      pl.ds(offset, bd2_per_t_packing)],
+                    dst_ref=b_b2_x2_vmem.at[bd2_id % 2, p],
+                    sem=local_sems.at[bw2_sem_id, 2],
+                ).start()
 
     def start_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id):
         for p in range(t_packing):
-            offset = p *
+            offset = p * h_per_t_packing + bd3_id * bd1_per_t_packing
             pltpu.make_async_copy(
                 src_ref=w1_hbm.at[
                     local_e_id,
                     1,
-                    pl.ds(offset,
+                    pl.ds(offset, bd1_per_t_packing),
                     pl.ds(bf_id * bf, bf),
                 ],
                 dst_ref=b_w3_x2_vmem.at[bw3_sem_id, p],
                 sem=local_sems.at[bw3_sem_id, 3],
             ).start()
+            if w1_scale_hbm is not None:
+                assert subc_quant_wsz is not None
+                pltpu.make_async_copy(
+                    src_ref=w1_scale_hbm.at[
+                        local_e_id,
+                        1,
+                        pl.ds(
+                            offset // subc_quant_wsz,
+                            bd1_per_t_packing // subc_quant_wsz,
+                        ),
+                        pl.ds(0, 1),
+                        pl.ds(bf_id * bf, bf),
+                    ],
+                    dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id, p],
+                    sem=local_sems.at[bw3_sem_id, 3],
+                ).start()
+            if b1_hbm is not None and bd3_id == 0:
+                pltpu.make_async_copy(
+                    src_ref=b1_hbm.at[local_e_id, 1,
+                                      pl.ds(0, 1),
+                                      pl.ds(bf_id * bf, bf)],
+                    dst_ref=b_b3_x2_vmem.at[bf_id % 2],
+                    sem=local_sems.at[bw3_sem_id, 3],
+                ).start()
 
     def wait_fetch_bw1(local_e_id, bw1_sem_id, bf_id, bd1_id):
-        del local_e_id
+        del local_e_id
         pltpu.make_async_copy(
             src_ref=b_w1_x2_vmem.at[bw1_sem_id],
             dst_ref=b_w1_x2_vmem.at[bw1_sem_id],
             sem=local_sems.at[bw1_sem_id, 1],
         ).wait()
+        if w1_scale_hbm is not None:
+            pltpu.make_async_copy(
+                src_ref=b_w1_scale_x2_vmem.at[bw1_sem_id],
+                dst_ref=b_w1_scale_x2_vmem.at[bw1_sem_id],
+                sem=local_sems.at[bw1_sem_id, 1],
+            ).wait()
+        if b1_hbm is not None and bd1_id == 0:
+            pltpu.make_async_copy(
+                src_ref=b_b1_x2_vmem.at[bf_id % 2],
+                dst_ref=b_b1_x2_vmem.at[bf_id % 2],
+                sem=local_sems.at[bw1_sem_id, 1],
+            ).wait()
 
     def wait_fetch_bw2(local_e_id, bw2_sem_id, bf_id, bd2_id):
-        del local_e_id
+        del local_e_id
         pltpu.make_async_copy(
             src_ref=b_w2_x2_vmem.at[bw2_sem_id],
             dst_ref=b_w2_x2_vmem.at[bw2_sem_id],
             sem=local_sems.at[bw2_sem_id, 2],
         ).wait()
+        if w2_scale_hbm is not None:
+            pltpu.make_async_copy(
+                src_ref=b_w2_scale_x2_vmem.at[bw2_sem_id],
+                dst_ref=b_w2_scale_x2_vmem.at[bw2_sem_id],
+                sem=local_sems.at[bw2_sem_id, 2],
+            ).wait()
+        if b2_hbm is not None and bf_id == 0:
+            pltpu.make_async_copy(
+                src_ref=b_b2_x2_vmem.at[bd2_id % 2],
+                dst_ref=b_b2_x2_vmem.at[bd2_id % 2],
+                sem=local_sems.at[bw2_sem_id, 2],
+            ).wait()
 
     def wait_fetch_bw3(local_e_id, bw3_sem_id, bf_id, bd3_id):
-        del local_e_id
+        del local_e_id
         pltpu.make_async_copy(
             src_ref=b_w3_x2_vmem.at[bw3_sem_id],
             dst_ref=b_w3_x2_vmem.at[bw3_sem_id],
             sem=local_sems.at[bw3_sem_id, 3],
         ).wait()
+        if w1_scale_hbm is not None:
+            pltpu.make_async_copy(
+                src_ref=b_w3_scale_x2_vmem.at[bw3_sem_id],
+                dst_ref=b_w3_scale_x2_vmem.at[bw3_sem_id],
+                sem=local_sems.at[bw3_sem_id, 3],
+            ).wait()
+        if b1_hbm is not None and bd3_id == 0:
+            pltpu.make_async_copy(
+                src_ref=b_b3_x2_vmem.at[bf_id % 2],
+                dst_ref=b_b3_x2_vmem.at[bf_id % 2],
+                sem=local_sems.at[bw3_sem_id, 3],
+            ).wait()
 
     def start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, bd2_id):
         next_bd1_id = bd1_id + 1
@@ -520,18 +718,38 @@ def _fused_ep_moe_kernel(
     def dynamic_ffn1(
         t_b32_vmem,
         w1_vmem,
+        w1_scale_vmem,
+        b1_vmem,
         w3_vmem,
+        w3_scale_vmem,
+        b3_vmem,
         acc1_vmem,
         acc3_vmem,
         dyn_sz,
         should_init,
     ):
         assert t_b32_vmem.shape == (bt * num_devices, bd1 // t_packing)
-        assert w1_vmem.shape == w3_vmem.shape == (t_packing,
+        assert w1_vmem.shape == w3_vmem.shape == (t_packing, bd1_per_t_packing,
                                                   bf)
         assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf)
         assert bd1 % (t_packing * 128) == 0, (bd1, t_packing)
         assert bd1c % (t_packing * 128) == 0, (bd1c, t_packing)
+        if w1_scale_vmem is not None:
+            assert w1_scale_vmem.shape == (
+                t_packing,
+                bd1_per_t_packing // subc_quant_wsz,
+                1,
+                bf,
+            )
+            assert bd1c_per_t_packing == subc_quant_wsz
+        if w3_scale_vmem is not None:
+            assert w3_scale_vmem.shape == (
+                t_packing,
+                bd1_per_t_packing // subc_quant_wsz,
+                1,
+                bf,
+            )
+            assert bd1c_per_t_packing == subc_quant_wsz
 
         num_loops = cdiv(dyn_sz, btc)
         repack_ty = jnp.dtype(f"int{t_bitwidth}")
@@ -540,7 +758,7 @@ def _fused_ep_moe_kernel(
             for bd1c_id in range(cdiv(bd1, bd1c)):
                 t_b32 = t_b32_vmem[
                     pl.ds(btc_id * btc, btc),
-                    pl.ds(bd1c_id *
+                    pl.ds(bd1c_id * bd1c_per_t_packing, bd1c_per_t_packing),
                 ]
                 for p_id in range(t_packing):
                     t = pltpu.bitcast(t_b32.astype(repack_ty), t_dtype)
@@ -548,21 +766,64 @@ def _fused_ep_moe_kernel(
                     for bfc_id in range(cdiv(bf, bfc)):
                         w_slices = (
                             p_id,
-                            pl.ds(bd1c_id *
-
+                            pl.ds(bd1c_id * bd1c_per_t_packing,
+                                  bd1c_per_t_packing),
                             pl.ds(bfc_id * bfc, bfc),
                         )
                         w1 = w1_vmem[*w_slices]
                         acc1 = jnp.dot(t,
                                        w1,
                                        preferred_element_type=jnp.float32)
+
+                        if w1_scale_vmem is not None:
+                            w1_scale_slices = (
+                                p_id,
+                                bd1c_id,
+                                pl.ds(0, 1),
+                                pl.ds(bfc_id * bfc, bfc),
+                            )
+                            # TODO(jevinjiang): can use mosaic to load with stride 0.
+                            w1_scale = jnp.broadcast_to(
+                                w1_scale_vmem[*w1_scale_slices], acc1.shape)
+                            acc1 *= w1_scale
+
                         w3 = w3_vmem[*w_slices]
+
                         acc3 = jnp.dot(t,
                                        w3,
                                        preferred_element_type=jnp.float32)
+
+                        if w3_scale_vmem is not None:
+                            w3_scale_slices = (
+                                p_id,
+                                bd1c_id,
+                                pl.ds(0, 1),
+                                pl.ds(bfc_id * bfc, bfc),
+                            )
+                            w3_scale = jnp.broadcast_to(
+                                w3_scale_vmem[*w3_scale_slices], acc3.shape)
+                            acc3 *= w3_scale
+
                         acc_slices = (pl.ds(btc_id * btc,
                                             btc), pl.ds(bfc_id * bfc, bfc))
                         if should_init and p_id == bd1c_id == 0:
+                            if b1_vmem is not None:
+                                b1_scale_slices = (
+                                    pl.ds(0, 1),
+                                    pl.ds(bfc_id * bfc, bfc),
+                                )
+                                b1 = jnp.broadcast_to(
+                                    b1_vmem[*b1_scale_slices], acc1.shape)
+                                acc1 += b1
+                            if b3_vmem is not None:
+                                b3_scale_slices = (
+                                    pl.ds(0, 1),
+                                    pl.ds(bfc_id * bfc, bfc),
+                                )
+                                b3 = jnp.broadcast_to(
+                                    b3_vmem[*b3_scale_slices], acc1.shape)
+                                acc3 += b3
+
                             acc1_vmem[*acc_slices] = acc1
                             acc3_vmem[*acc_slices] = acc3
                         else:
@@ -575,22 +836,28 @@ def _fused_ep_moe_kernel(
         acc1_vmem,
         acc3_vmem,
         w2_vmem,
+        w2_scale_vmem,
+        b2_vmem,
         res_b32_vmem,
         dyn_sz,
         should_init,
     ):
-        assert res_b32_vmem.shape == (bt * num_devices,
-        assert w2_vmem.shape == (t_packing, bf,
-            w2_vmem.shape,
-            t_packing,
-            bf,
-            bd2_per_packing,
-        )
+        assert res_b32_vmem.shape == (bt * num_devices, bd2_per_t_packing)
+        assert w2_vmem.shape == (t_packing, bf, bd2_per_t_packing)
         assert acc1_vmem.shape == acc3_vmem.shape == (bt * num_devices, bf)
         assert bd2 % (t_packing * 128) == 0, (bd2, t_packing)
         assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing)
         assert t_dtype in (jnp.float32, jnp.bfloat16)
 
+        if w2_scale_vmem is not None:
+            assert w2_scale_vmem.shape == (
+                t_packing,
+                bf // subc_quant_wsz,
+                1,
+                bd2_per_t_packing,
+            )
+            assert bfc == subc_quant_wsz
+
         num_loops = cdiv(dyn_sz, btc)
         assert bd2c % (t_packing * 128) == 0, (bd2c, t_packing)
 
@@ -598,22 +865,47 @@ def _fused_ep_moe_kernel(
             for bd2c_id in range(cdiv(bd2, bd2c)):
                 res_lst = []
                 for p_id in range(t_packing):
-                    res = jnp.zeros((btc,
+                    res = jnp.zeros((btc, bd2c_per_t_packing),
+                                    dtype=jnp.float32)
+
+                    if b2_vmem is not None and should_init:
+                        b2_scale_slices = (
+                            p_id,
+                            pl.ds(0, 1),
+                            pl.ds(bd2c_id * bd2c_per_t_packing,
+                                  bd2c_per_t_packing),
+                        )
+                        b2 = jnp.broadcast_to(b2_vmem[*b2_scale_slices],
+                                              res.shape)
+                        res += b2
+
                     for bfc_id in range(cdiv(bf, bfc)):
                         acc_slices = (pl.ds(btc_id * btc,
                                             btc), pl.ds(bfc_id * bfc, bfc))
                         acc1 = acc1_vmem[*acc_slices]
                         acc3 = acc3_vmem[*acc_slices]
-                        act =
+                        act = activation_fn(acc1, acc3, act_fn)
                         w2 = w2_vmem[
                             p_id,
                             pl.ds(bfc_id * bfc, bfc),
                             pl.ds(bd2c_id *
-
+                                  bd2c_per_t_packing, bd2c_per_t_packing),
                         ]
-
-
-
+                        acc = jnp.dot(act,
+                                      w2,
+                                      preferred_element_type=jnp.float32)
+                        if w2_scale_vmem is not None:
+                            w2_scale_slices = (
+                                p_id,
+                                bfc_id,
+                                pl.ds(0, 1),
+                                pl.ds(bd2c_id * bd2c_per_t_packing,
+                                      bd2c_per_t_packing),
+                            )
+                            w2_scale = jnp.broadcast_to(
+                                w2_scale_vmem[*w2_scale_slices], acc.shape)
+                            acc *= w2_scale
+                        res += acc
                     res = pltpu.bitcast(res, jnp.uint32)
                     if t_packing == 2:
                        res = res >> 16 << (16 * p_id)
@@ -626,7 +918,7 @@ def _fused_ep_moe_kernel(
                         res |= res_lst[i]
                     sliced_res_vmem = res_b32_vmem.at[
                         pl.ds(btc_id * btc, btc),
-                        pl.ds(bd2c_id *
+                        pl.ds(bd2c_id * bd2c_per_t_packing, bd2c_per_t_packing),
                     ]
                     if should_init:
                         sliced_res_vmem[...] = res
@@ -655,21 +947,33 @@ def _fused_ep_moe_kernel(
             e_id = my_id * local_num_experts + local_e_id
             dyn_sz = expert_sizes_x2_smem[bt_sem_id, 0, e_id]
 
-
-
+            bd1_per_t_packing = bd1 // t_packing
+            bd2_per_t_packing = bd2 // t_packing
 
             for bf_id in range(num_bf):
                 for bd1_id in range(num_bd1):
                     start_fetch_next_bw(local_e_id, bw_sem_id, bf_id, bd1_id, 0)
+                    w1_scale_vmem = (None if b_w1_scale_x2_vmem is None else
+                                     b_w1_scale_x2_vmem.at[bw_sem_id])
+                    w3_scale_vmem = (None if b_w3_scale_x2_vmem is None else
+                                     b_w3_scale_x2_vmem.at[bw_sem_id])
+                    b1_vmem = None if b_b1_x2_vmem is None else b_b1_x2_vmem.at[
+                        bf_id % 2]
+                    b3_vmem = None if b_b3_x2_vmem is None else b_b3_x2_vmem.at[
+                        bf_id % 2]
                     wait_fetch_bw1(local_e_id, bw_sem_id, bf_id, bd1_id)
                     wait_fetch_bw3(local_e_id, bw_sem_id, bf_id, bd1_id)
 
                     dynamic_ffn1(
                         t_b32_vmem=a2a_s_b32_vmem.at[
                             ...,
-                            pl.ds(bd1_id *
+                            pl.ds(bd1_id * bd1_per_t_packing, bd1_per_t_packing)],
                         w1_vmem=b_w1_x2_vmem.at[bw_sem_id],
+                        w1_scale_vmem=w1_scale_vmem,
+                        b1_vmem=b1_vmem,
                         w3_vmem=b_w3_x2_vmem.at[bw_sem_id],
+                        w3_scale_vmem=w3_scale_vmem,
+                        b3_vmem=b3_vmem,
                         acc1_vmem=b_acc1_vmem,
                         acc3_vmem=b_acc3_vmem,
                         dyn_sz=dyn_sz,
@@ -684,13 +988,19 @@ def _fused_ep_moe_kernel(
                     if bf_id == bd2_id == 0:
                         wait_a2a_gather_send(bt_id, e_sem_id, local_e_id - 2)
 
+                    w2_scale_vmem = (None if b_w2_scale_x2_vmem is None else
+                                     b_w2_scale_x2_vmem.at[bw_sem_id])
+                    b2_vmem = None if b_b2_x2_vmem is None else b_b2_x2_vmem.at[
+                        bd2_id % 2]
                     dynamic_ffn2(
                         acc1_vmem=b_acc1_vmem,
                         acc3_vmem=b_acc3_vmem,
                         w2_vmem=b_w2_x2_vmem.at[bw_sem_id],
+                        w2_scale_vmem=w2_scale_vmem,
+                        b2_vmem=b2_vmem,
                         res_b32_vmem=a2a_s_acc_b32_vmem.at[
                             ...,
-                            pl.ds(bd2_id *
+                            pl.ds(bd2_id * bd2_per_t_packing, bd2_per_t_packing)],
                         dyn_sz=dyn_sz,
                         should_init=(bf_id == 0),
                     )
@@ -757,7 +1067,7 @@ def _fused_ep_moe_kernel(
             b_gating = b_gating_x2_vmem[bt_sem_id]
             b_gating_score = jax.nn.softmax(b_gating, axis=-1)
             top_k_logits_lst, t2e_routing, expert_sizes, expert_starts = get_top_k(
-                b_gating_score, top_k)
+                b_gating_score, top_k, renormalize_topk_logits)
 
             all_reduce_metadata(bt_sem_id, t2e_routing, expert_starts,
                                 expert_sizes)
@@ -827,6 +1137,9 @@ def _fused_ep_moe_kernel(
     static_argnames=[
         "mesh",
         "top_k",
+        "renormalize_topk_logits",
+        "act_fn",
+        "subc_quant_wsz",
         "bt",
         "bf",
         "bd1",
@@ -845,7 +1158,18 @@ def fused_ep_moe(
     w2: jax.Array,  # (num_experts, intermediate_size, hidden_size)
     gating_output: jax.Array,  # (num_tokens, num_experts)
     top_k: int,
+    renormalize_topk_logits: bool = False,
+    act_fn: str = "silu",
     *,
+    subc_quant_wsz: int | None = None,
+    w1_scale: (
+        jax.Array | None
+    ) = None,  # (num_experts, 2, cdiv(hidden_size, subc_quant_wsz), intermediate_size)
+    w2_scale: (
+        jax.Array | None
+    ) = None,  # (num_experts, cdiv(intermediate_size, subc_quant_wsz), hidden_size)
+    b1: jax.Array | None = None,  # (num_experts, 2, intermediate_size)
+    b2: jax.Array | None = None,  # (num_experts, hidden_size)
     # Kernel tuning parameters.
     bt: int,
     bf: int,
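Note: the new optional arguments are documented only through the shape comments above. A small sketch of how the scale and bias operands relate to w1/w2 under sub-channel quantization (numbers are illustrative; `cdiv` below is a local stand-in for the kernel's own helper):

```python
import jax.numpy as jnp


def cdiv(a, b):  # local stand-in for the kernel's cdiv utility
    return (a + b - 1) // b


num_experts, hidden_size, intermediate_size = 8, 2048, 5120
subc_quant_wsz = 256

w1 = jnp.zeros((num_experts, 2, hidden_size, intermediate_size), jnp.bfloat16)
w2 = jnp.zeros((num_experts, intermediate_size, hidden_size), jnp.bfloat16)

# One scale row per subc_quant_wsz-wide slice of the contracting dimension.
w1_scale = jnp.ones((num_experts, 2, cdiv(hidden_size, subc_quant_wsz), intermediate_size))
w2_scale = jnp.ones((num_experts, cdiv(intermediate_size, subc_quant_wsz), hidden_size))

# Optional biases for the two FFN projections (e.g. GPT-OSS-style MoE blocks).
b1 = jnp.zeros((num_experts, 2, intermediate_size))
b2 = jnp.zeros((num_experts, hidden_size))
```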
@@ -855,18 +1179,19 @@ def fused_ep_moe(
     bfc: int,
     bd1c: int,
     bd2c: int,
-    ep_axis_name: str =
+    ep_axis_name: str = "model",
 ):
+    # TODO(jevinjiang): move all these assertions to validation function.
     # Assert all other axes have length of 1
-    assert len(mesh.shape) == 2, "Expect 2D mesh
-    assert
-
+    assert len(mesh.shape) == 2, "Expect 2D mesh"
+    assert ("data" in mesh.shape
+            and mesh.shape["data"] == 1), "Expect data axis size of 1"
 
     ep_size = mesh.shape[ep_axis_name]
     num_devices = ep_size
 
     num_tokens, actual_hidden_size = tokens.shape
-    num_experts,
+    num_experts, actual_intermediate_size, _ = w2.shape
 
     assert num_tokens % ep_size == 0
     assert num_experts % ep_size == 0
@@ -874,26 +1199,18 @@ def fused_ep_moe(
     local_num_tokens = num_tokens // ep_size
     # local_num_experts = num_experts // ep_size
     padded_num_experts = align_to(num_experts, 128)
-
     t_dtype = tokens.dtype
     t_packing = get_dtype_packing(t_dtype)
-    hidden_size = align_to(actual_hidden_size, 128 * t_packing)
-    if hidden_size != actual_hidden_size:
-        tokens = jnp.pad(
-            tokens,
-            ((0, 0), (0, hidden_size - actual_hidden_size)),
-            constant_values=0,
-        )
-    tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing)
-    bt = min(bt, local_num_tokens)
-    bf = min(bf, intermediate_size)
-    bd1 = min(bd1, hidden_size)
-    bd2 = min(bd2, hidden_size)
 
-
-
-
-
+    if subc_quant_wsz is not None:
+        if subc_quant_wsz % 256 != 0:
+            raise NotImplementedError(
+                "Sub-quantized window is not aligned to 256.")
+        # We force compute size of contracting dim to subc_quant_wsz. So we can
+        # apply same scale after matmul and accumulation.
+        bd1c = subc_quant_wsz * t_packing
+        bfc = subc_quant_wsz
+
     assert bfc % 128 == 0
     assert bd1c % (t_packing * 128) == 0
     assert bd2c % (t_packing * 128) == 0
@@ -901,6 +1218,41 @@ def fused_ep_moe(
     assert bd1 % bd1c == 0
     assert bd2 % bd2c == 0
 
+    btc = min(btc, bt * num_devices)
+    hidden_size = align_to(actual_hidden_size, 128 * t_packing)
+    # TODO(jevinjiang): instead of padding outside the kernel, we can try dynammic
+    # masking inside the kernel.
+    hidden_size = align_to(hidden_size, bd1)
+    hidden_size = align_to(hidden_size, bd2)
+    intermediate_size = align_to(actual_intermediate_size, bf)
+
+    # TODO(jevinjiang): we should dump scale as the kernel expected shape in the
+    # checkpoint offline or reshape right after weight loading.
+    if w1_scale is not None:
+        assert w1_scale.shape[0] == w1.shape[0]
+        assert w1_scale.shape[1] == w1.shape[1] == 2
+        assert w1_scale.shape[2] == cdiv(w1.shape[2], subc_quant_wsz)
+        assert w1_scale.shape[3] == w1.shape[3]
+        w1_scale = jnp.expand_dims(w1_scale.astype(jnp.float32), axis=-2)
+
+    if w2_scale is not None:
+        assert w2_scale.shape[0] == w2.shape[0]
+        assert w2_scale.shape[1] == cdiv(w2.shape[1], subc_quant_wsz)
+        assert w2_scale.shape[2] == w2.shape[2]
+        w2_scale = jnp.expand_dims(w2_scale.astype(jnp.float32), axis=-2)
+
+    if b1 is not None:
+        assert b1.shape[0] == w1.shape[0]
+        assert b1.shape[1] == w1.shape[1] == 2
+        assert b1.shape[2] == w1.shape[3]
+        b1 = jnp.expand_dims(b1.astype(jnp.float32), axis=-2)
+
+    if b2 is not None:
+        assert b2.shape[0] == w2.shape[0]
+        assert b2.shape[1] == w2.shape[2]
+        b2 = jnp.expand_dims(b2.astype(jnp.float32), axis=-2)
+
+    # Prepare inputs for the kernel.
     if padded_num_experts != gating_output.shape[-1]:
         gating_output = jnp.pad(
             gating_output,
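Note: the hidden and intermediate sizes are now padded up to multiples of the block sizes before launching the kernel. A tiny sketch of that alignment logic (the `align_to` below is a local stand-in; numbers are illustrative):

```python
def align_to(x, a):  # local stand-in for the kernel's align_to utility
    return (x + a - 1) // a * a


t_packing = 2            # e.g. bf16 tokens pack two values per 32-bit word
actual_hidden_size = 2816
bd1 = bd2 = 512

hidden_size = align_to(actual_hidden_size, 128 * t_packing)  # 2816, already a multiple of 256
hidden_size = align_to(hidden_size, bd1)                     # 3072
hidden_size = align_to(hidden_size, bd2)                     # 3072
print(hidden_size)
```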
@@ -908,13 +1260,92 @@ def fused_ep_moe(
             constant_values=-jnp.inf,
         )
 
-
+    if (hidden_size != actual_hidden_size
+            or intermediate_size != actual_intermediate_size):
+        tokens = jnp.pad(
+            tokens,
+            ((0, 0), (0, hidden_size - actual_hidden_size)),
+            constant_values=0,
+        )
+        w1 = jnp.pad(
+            w1,
+            (
+                (0, 0),
+                (0, 0),
+                (0, hidden_size - actual_hidden_size),
+                (0, intermediate_size - actual_intermediate_size),
+            ),
+            constant_values=0,
+        )
+        w2 = jnp.pad(
+            w2,
+            (
+                (0, 0),
+                (0, intermediate_size - actual_intermediate_size),
+                (0, hidden_size - actual_hidden_size),
+            ),
+            constant_values=0,
+        )
+        if w1_scale is not None:
+            w1_scale = jnp.pad(
+                w1_scale,
+                (
+                    (0, 0),
+                    (0, 0),
+                    (0,
+                     cdiv(hidden_size, subc_quant_wsz) - w1_scale.shape[-3]),
+                    (0, 0),
+                    (0, intermediate_size - w1_scale.shape[-1]),
+                ),
+                constant_values=0,
+            )
+        if w2_scale is not None:
+            w2_scale = jnp.pad(
+                w2_scale,
+                (
+                    (0, 0),
+                    (0, cdiv(intermediate_size, subc_quant_wsz) -
+                     w2_scale.shape[-3]),
+                    (0, 0),
+                    (0, hidden_size - w2_scale.shape[-1]),
+                ),
+                constant_values=0,
+            )
+        if b1 is not None:
+            b1 = jnp.pad(
+                b1,
+                (
+                    (0, 0),
+                    (0, 0),
+                    (0, 0),
+                    (0, intermediate_size - b1.shape[-1]),
+                ),
+                constant_values=0,
+            )
+        if b2 is not None:
+            b2 = jnp.pad(
+                b2,
+                (
+                    (0, 0),
+                    (0, 0),
+                    (0, hidden_size - b2.shape[-1]),
+                ),
+                constant_values=0,
+            )
+
+    tokens = tokens.reshape(-1, t_packing, hidden_size // t_packing)
+
+    hbm_block_spec = pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM)
+    scope_name = f"fused_moe_k-{top_k}_renorm-{renormalize_topk_logits}_bt-{bt}-{btc}_bf-{bf}-{bfc}_bd1-{bd1}-{bd1c}_bd2-{bd2}-{bd2c}"
     fused_moe = jax.named_scope(scope_name)(
         pl.pallas_call(
             functools.partial(
                 _fused_ep_moe_kernel,
                 top_k=top_k,
+                renormalize_topk_logits=renormalize_topk_logits,
                 ep_axis_name=ep_axis_name,
+                act_fn=act_fn,
+                subc_quant_wsz=subc_quant_wsz,
                 bt=bt,
                 bf=bf,
                 bd1=bd1,
@@ -929,11 +1360,17 @@ def fused_ep_moe(
             grid_spec=pltpu.PrefetchScalarGridSpec(
                 num_scalar_prefetch=0,
                 in_specs=[
-
-
-
-
-
+                    hbm_block_spec,  # tokens_hbm
+                    hbm_block_spec,  # w1_hbm
+                    hbm_block_spec,  # w2_hbm
+                    None
+                    if w1_scale is None else hbm_block_spec,  # w1_scale_hbm
+                    None
+                    if w2_scale is None else hbm_block_spec,  # w2_scale_hbm
+                    None if b1 is None else hbm_block_spec,  # b1_hbm
+                    None if b2 is None else hbm_block_spec,  # b2_hbm
+                    hbm_block_spec,  # gating_output_hbm
+                    hbm_block_spec,  # a2a_g_hbm
                 ],
                 out_specs=pl.BlockSpec(memory_space=pltpu.MemorySpace.HBM),
                 scratch_shapes=([
@@ -984,6 +1421,67 @@ def fused_ep_moe(
                     pltpu.VMEM((2, t_packing, bd1 // t_packing, bf), w1.dtype),
                     # b_w2_x2_vmem
                     pltpu.VMEM((2, t_packing, bf, bd2 // t_packing), w2.dtype),
+                    # b_w1_scale_x2_vmem
+                    (None if w1_scale is None else pltpu.VMEM(
+                        (
+                            2,
+                            t_packing,
+                            bd1 // t_packing // subc_quant_wsz,
+                            1,
+                            bf,
+                        ),
+                        jnp.float32,
+                    )),
+                    # b_w3_scale_x2_vmem
+                    (None if w1_scale is None else pltpu.VMEM(
+                        (
+                            2,
+                            t_packing,
+                            bd1 // t_packing // subc_quant_wsz,
+                            1,
+                            bf,
+                        ),
+                        jnp.float32,
+                    )),
+                    # b_w2_scale_x2_vmem
+                    (None if w2_scale is None else pltpu.VMEM(
+                        (
+                            2,
+                            t_packing,
+                            bf // subc_quant_wsz,
+                            1,
+                            bd2 // t_packing,
+                        ),
+                        jnp.float32,
+                    )),
+                    # b_b1_x2_vmem
+                    (None if b1 is None else pltpu.VMEM(
+                        (
+                            2,
+                            1,
+                            bf,
+                        ),
+                        jnp.float32,
+                    )),
+                    # b_b3_x2_vmem
+                    (None if b1 is None else pltpu.VMEM(
+                        (
+                            2,
+                            1,
+                            bf,
+                        ),
+                        jnp.float32,
+                    )),
+                    # b_b2_x2_vmem
+                    (None if b2 is None else pltpu.VMEM(
+                        (
+                            2,
+                            t_packing,
+                            1,
+                            bd2 // t_packing,
+                        ),
+                        jnp.float32,
+                    )),
                     # b_acc_vmem
                     pltpu.VMEM((bt * num_devices, 1, bf * 2), jnp.float32),
                     # local_sems
@@ -1006,21 +1504,50 @@ def fused_ep_moe(
             ))
 
     @jax.jit
-    @
-        shard_map.shard_map,
+    @jax.shard_map(
         mesh=mesh,
-        in_specs=(
-
+        in_specs=(
+            P(ep_axis_name),  # tokens_hbm
+            P(ep_axis_name),  # w1_hbm
+            P(ep_axis_name),  # w2_hbm
+            None if w1_scale is None else P(ep_axis_name),  # w1_scale_hbm
+            None if w2_scale is None else P(ep_axis_name),  # w2_scale_hbm
+            None if b1 is None else P(ep_axis_name),  # b1_hbm
+            None if b2 is None else P(ep_axis_name),  # b2_hbm
+            P(ep_axis_name),  # gating_output_hbm
+            P(),  # a2a_g_hbm
+        ),
         out_specs=P(ep_axis_name),
-
+        check_vma=False,
     )
-    def kernel(
+    def kernel(
+        tokens,
+        w1,
+        w2,
+        w1_scale,
+        w2_scale,
+        b1,
+        b2,
+        gating_output,
+        a2a_g_hbm_scratch,
+    ):
         return fused_moe(
-            pltpu.with_memory_space_constraint(tokens,
-
-            pltpu.with_memory_space_constraint(
-            pltpu.with_memory_space_constraint(
-            pltpu.with_memory_space_constraint(
+            pltpu.with_memory_space_constraint(tokens,
+                                               pltpu.HBM),  # tokens_hbm
+            pltpu.with_memory_space_constraint(w1, pltpu.HBM),  # w1_hbm
+            pltpu.with_memory_space_constraint(w2, pltpu.HBM),  # w2_hbm
+            (None if w1_scale is None else pltpu.with_memory_space_constraint(
+                w1_scale, pltpu.HBM)),  # w1_scale_hbm
+            (None if w2_scale is None else pltpu.with_memory_space_constraint(
+                w2_scale, pltpu.HBM)),  # w2_scale_hbm
+            (None if b1 is None else pltpu.with_memory_space_constraint(
+                b1, pltpu.HBM)),  # b1_hbm
+            (None if b2 is None else pltpu.with_memory_space_constraint(
+                b2, pltpu.HBM)),  # b2_hbm
+            pltpu.with_memory_space_constraint(gating_output,
+                                               pltpu.HBM),  # gating_output_hbm
+            pltpu.with_memory_space_constraint(a2a_g_hbm_scratch,
+                                               pltpu.HBM),  # a2a_g_hbm
         )
 
     a2a_g_hbm_scratch = pl.empty(
@@ -1029,6 +1556,10 @@ def fused_ep_moe(
         tokens,
         w1,
         w2,
+        w1_scale,
+        w2_scale,
+        b1,
+        b2,
         gating_output,
         a2a_g_hbm_scratch,
     )