tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511180814__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +34 -303
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
- tests/lora/test_layers.py +6 -0
- tests/lora/utils.py +8 -0
- tests/test_envs.py +11 -32
- tests/test_utils.py +2 -1
- tpu_inference/__init__.py +3 -22
- tpu_inference/core/disagg_utils.py +8 -6
- tpu_inference/distributed/tpu_connector.py +4 -3
- tpu_inference/distributed/utils.py +2 -3
- tpu_inference/envs.py +8 -61
- tpu_inference/executors/ray_distributed_executor.py +2 -9
- tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +145 -266
- tpu_inference/layers/common/attention_interface.py +1 -7
- tpu_inference/layers/common/sharding.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +208 -170
- tpu_inference/layers/vllm/quantization/common.py +1 -6
- tpu_inference/layers/vllm/quantization/mxfp4.py +73 -138
- tpu_inference/layers/vllm/quantization/unquantized.py +64 -58
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +2 -1
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/common/model_loader.py +10 -43
- tpu_inference/models/jax/llama3.py +1 -2
- tpu_inference/models/jax/llama_eagle3.py +5 -8
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +1 -2
- tpu_inference/models/jax/qwen2_5_vl.py +48 -163
- tpu_inference/models/jax/qwen3.py +1 -2
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
- tpu_inference/models/jax/utils/weight_utils.py +143 -198
- tpu_inference/models/vllm/vllm_model_wrapper.py +8 -14
- tpu_inference/platforms/tpu_platform.py +31 -37
- tpu_inference/runner/compilation_manager.py +58 -141
- tpu_inference/runner/kv_cache.py +1 -1
- tpu_inference/runner/kv_cache_manager.py +18 -17
- tpu_inference/runner/persistent_batch_manager.py +2 -40
- tpu_inference/runner/structured_decoding_manager.py +3 -2
- tpu_inference/runner/tpu_runner.py +147 -271
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +21 -71
- tpu_inference/tpu_info.py +3 -4
- tpu_inference/utils.py +13 -36
- tpu_inference/worker/tpu_worker.py +25 -162
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA +3 -4
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/RECORD +55 -50
- tpu_inference/models/jax/llama_guard_4.py +0 -361
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/WHEEL +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/fused_moe.py

@@ -2,16 +2,17 @@ import functools
 
 import jax
 from jax import numpy as jnp
-from jax import shard_map
 from jax.experimental.pallas.ops.tpu.megablox.gmm import gmm
-from jax.
-from jax.sharding import
+from jax.experimental.shard_map import shard_map
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
 
 from tpu_inference.layers.vllm.linear_common import \
     slice_sharded_tensor_for_concatenation
 
+P = PartitionSpec
 
-def activation_fn(activation: str, x1: jax.Array, x2: jax.Array) -> jax.Array:
+
+def activation_fn(activation: str, x1, x2):
     match activation:
         case "silu":
             return jax.nn.silu(x1) * x2
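
The import churn above tracks where `shard_map` lives: the new code pulls it from `jax.experimental.shard_map` and takes `Mesh`/`NamedSharding`/`PartitionSpec` from `jax.sharding`, with `P` aliased to `PartitionSpec`. As a reference for how the rest of this file combines those pieces, here is a minimal, hedged sketch of the shard_map-plus-psum pattern; the mesh construction, the single `"model"` axis, and the shapes are illustrative assumptions, not code from the package:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec

P = PartitionSpec

# Illustrative 1-D mesh over whatever devices are available; the real module
# receives its mesh (with a "model" axis) from the caller.
mesh = Mesh(np.array(jax.devices()), axis_names=("model",))

def local_matmul_psum(x, w):
    # Each shard multiplies its local slice, then all-reduces over "model",
    # mirroring the row-parallel pattern used in this file.
    return jax.lax.psum(x @ w, axis_name="model")

x = jnp.ones((8, 16))
w = jnp.ones((16, 4))

y = shard_map(
    local_matmul_psum,
    mesh=mesh,
    in_specs=(P(None, "model"), P("model", None)),
    out_specs=P(),
    check_rep=False,
)(x, w)
print(y.shape)  # (8, 4), replicated across the mesh
```

On a single device the "model" axis has size 1 and the psum is a no-op, so the sketch runs anywhere JAX runs.
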
@@ -22,10 +23,7 @@ def activation_fn(activation: str, x1: jax.Array, x2: jax.Array) -> jax.Array:
                 f"FusedMoE does not support {activation} activation")
 
 
-def _swigluoai(x1
-               x2: jax.Array,
-               alpha=1.702,
-               limit=7.0) -> jax.Array:
+def _swigluoai(x1, x2, alpha=1.702, limit=7.0):
     x1 = jnp.clip(x1, a_max=limit)
     x2 = jnp.clip(x2, a_min=-limit, a_max=limit)
 
@@ -105,53 +103,40 @@ def tensor_sharded_gmm_merged_column_parallel(
     rhs: jax.Array,
     rhs_bias: jax.Array | None,
     group_sizes: jax.Array,
+    transpose_rhs: bool,
     mesh: Mesh,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    intermediate_size: int,
+) -> jax.Array:
+    # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
+    m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
+    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+
+    _gmm = functools.partial(
+        gmm,
+        preferred_element_type=lhs.dtype,
+        tiling=(tm, tk, tn),
+        transpose_rhs=transpose_rhs,
+        group_offset=jnp.array(0),
+    )
 
     gmm_result = shard_map(
         _gmm,
         mesh=mesh,
-        in_specs=(P(
-        out_specs=(P(
-
+        in_specs=(P(), P(None, "model", None), P()),
+        out_specs=(P(None, "model")),
+        check_rep=False,
     )(lhs, rhs, group_sizes)
 
     if rhs_bias is not None:
+        rhs_bis = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
+        gmm_result = (gmm_result + rhs_bis).astype(gmm_result.dtype)
 
-
-        rhs_bias = jnp.repeat(
-            rhs_bias_local,
-            group_sizes_global,
-            0,
-            total_repeat_length=gmm_result_local.shape[0])
-        return gmm_result_local + rhs_bias
-
-    gmm_result = shard_map(
-        _add_bias,
-        mesh=mesh,
-        in_specs=(P("data", "model"), P(None, "model"), P("data")),
-        out_specs=(P("data", "model")),
-    )(gmm_result, rhs_bias, group_sizes)
-    gmm_result = gmm_result.astype(lhs.dtype)
-
-    tp_size = mesh.shape["model"]
-    intermediate_size = gmm_result.shape[-1] // 2
+    n_shards = mesh.shape["model"]
     output_sizes = [intermediate_size, intermediate_size]
+
     return slice_sharded_tensor_for_concatenation(gmm_result, output_sizes,
-
+                                                  n_shards)
 
 
 def tensor_sharded_gmm_row_parallel(
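
Both tensor-sharded paths now build the megablox `gmm` call once with `functools.partial` (tiling, `transpose_rhs`, `group_offset`) and wrap it in `shard_map`. What `gmm` computes is a grouped matmul: rows of `lhs` are split into contiguous groups by `group_sizes`, and group `i` is multiplied by its own expert matrix `rhs[i]`. Below is a dense `jnp`-only sketch of that contract, not the Pallas kernel itself; shapes and values are made up, and `transpose_rhs=True` is taken to mean `rhs` is laid out as `[g, n, k]`:

```python
import jax.numpy as jnp
import numpy as np

def gmm_reference(lhs, rhs, group_sizes, transpose_rhs=True):
    # lhs: [m, k]; rhs: [g, n, k] if transpose_rhs else [g, k, n]
    # group_sizes: length-g sequence summing to m; lhs rows are grouped contiguously.
    starts = np.concatenate(([0], np.cumsum(group_sizes)[:-1]))
    outs = []
    for g, size in enumerate(group_sizes):
        block = lhs[starts[g]:starts[g] + size]      # [size, k] rows for expert g
        w = rhs[g].T if transpose_rhs else rhs[g]    # [k, n]
        outs.append(block @ w)
    return jnp.concatenate(outs, axis=0)             # [m, n]

lhs = jnp.arange(24, dtype=jnp.float32).reshape(6, 4)   # m=6, k=4
rhs = jnp.ones((2, 3, 4), dtype=jnp.float32)            # g=2, n=3, k=4 (transposed layout)
print(gmm_reference(lhs, rhs, group_sizes=[2, 4]).shape)  # (6, 3)
```
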
@@ -159,75 +144,74 @@ def tensor_sharded_gmm_row_parallel(
     rhs: jax.Array,
     rhs_bias: jax.Array | None,
     group_sizes: jax.Array,
+    transpose_rhs: bool,
     mesh: Mesh,
 ) -> jax.Array:
+    # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
+    m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
+    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+
+    _gmm = functools.partial(
+        gmm,
+        preferred_element_type=lhs.dtype,
+        tiling=(tm, tk, tn),
+        transpose_rhs=transpose_rhs,
+        group_offset=jnp.array(0),
+    )
 
     def _gmm_all_reduce(lhs, rhs, group_sizes):
-
-
-        out = gmm(
-            lhs,
-            rhs,
-            group_sizes,
-            preferred_element_type=lhs.dtype,
-            tiling=(tm, tk, tn),
-            transpose_rhs=True,
-            group_offset=jnp.array(0),
-        )
-        return jax.lax.psum(out, axis_name="model")
+        r = _gmm(lhs, rhs, group_sizes)
+        return jax.lax.psum(r, axis_name="model")
 
     gmm_result = shard_map(
         _gmm_all_reduce,
         mesh=mesh,
-        in_specs=(P(
-        out_specs=(P(
-
+        in_specs=(P(None, "model"), P(None, None, "model"), P()),
+        out_specs=(P()),
+        check_rep=False,
     )(lhs, rhs, group_sizes)
 
     if rhs_bias is not None:
+        rhs_bias = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
+        gmm_result = (gmm_result + rhs_bias).astype(gmm_result.dtype)
 
-
-        rhs_bias = jnp.repeat(
-            rhs_bias_local,
-            group_sizes_global,
-            0,
-            total_repeat_length=gmm_result_local.shape[0])
-        return gmm_result_local + rhs_bias
-
-    gmm_result = shard_map(
-        _add_bias,
-        mesh=mesh,
-        in_specs=(P("data"), P(), P("data")),
-        out_specs=(P("data")),
-    )(gmm_result, rhs_bias, group_sizes)
-
-    return gmm_result.astype(lhs.dtype)
+    return gmm_result
 
 
 def expert_sharded_gmm(
     lhs: jax.Array,
     rhs: jax.Array,
     group_sizes: jax.Array,
+    transpose_rhs: bool,
     mesh: Mesh,
+    num_experts: int,
+    ep_size: int,
 ) -> jax.Array:
-
+    # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
+    m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
+    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
 
-    num_experts = rhs.shape[0]
     num_experts_per_shard = num_experts // ep_size
     group_offset = jnp.arange(0, num_experts, num_experts_per_shard)
+    group_offset = jax.lax.with_sharding_constraint(
+        group_offset, NamedSharding(mesh, P("model")))
 
     def _gmm(lhs, rhs, group_sizes, group_offset):
-
-
-
+        # Group offset for this shard. `group_offset` is sharded, and in this
+        # sharded function, it has only 1 element and `group_offset.shape` is
+        # (1,) but gmm kernel requires the group_offset to be a ()-shaped array,
+        # so we group_offset[0].
+        group_offset_of_shard = group_offset[0]
         gmm_res = gmm(
             lhs=lhs,
             rhs=rhs,
             group_sizes=group_sizes,
             preferred_element_type=lhs.dtype,
             tiling=(tm, tk, tn),
-            transpose_rhs=
-            group_offset=
+            transpose_rhs=transpose_rhs,
+            group_offset=group_offset_of_shard,
        )
         return gmm_res
 
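
Both sharded paths above now fold in the expert bias with `jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)`: the per-expert bias rows are stretched so that row `i` of the result carries the bias of the expert that owns row `i` of the expert-sorted gmm output. A small numeric sketch of that expansion (the bias values and group sizes are invented):

```python
import jax.numpy as jnp

m = 6                                            # total token*topk rows
rhs_bias = jnp.array([[10.0], [20.0], [30.0]])   # one bias row per expert, n=1
group_sizes = jnp.array([2, 1, 3])               # rows assigned to experts 0, 1, 2

# Expert 0's bias repeated 2x, expert 1's 1x, expert 2's 3x -> shape [m, n].
expanded = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
print(expanded.ravel())   # [10. 10. 20. 30. 30. 30.]
```
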
@@ -254,24 +238,30 @@ def expert_sharded_gmm(
         mesh=mesh,
         in_specs=(P(), P("model", None, None), P(), P("model")),
         out_specs=(P("model", None)),
-
+        check_rep=False,
     )(lhs, rhs, group_sizes, group_offset)
 
     # For i-th shard, it is responsible groups (AKA experts) from
     # i*num_experts_per_shard to (i+1)*num_experts_per_shard We sum them up to
     # get total rows in that shard, and that is the size for shard to send to
     # its peers. This is also the number of non-zero rows from the gmm results.
-    # In the working example, send_sizes would be [3, 2, 5, 4]
-
-
-
-
-    send_sizes = group_sizes.reshape(-1, num_experts_per_shard).sum(axis=1)
+    # In the working example, send_sizes would be [3, 2, 5, 4]
+    send_sizes = jnp.array([
+        group_sizes[i * num_experts_per_shard:(i + 1) *
+                    num_experts_per_shard].sum() for i in range(ep_size)
+    ])
     # In the working example, input_offsets would be [0, 3, 5, 10]
     input_offsets = jnp.concatenate((jnp.array([0]), send_sizes.cumsum()[:-1]))
     output_offsets = input_offsets
     recv_sizes = send_sizes
 
+    input_offsets = jax.lax.with_sharding_constraint(
+        input_offsets, NamedSharding(mesh, P("model")))
+    send_sizes = jax.lax.with_sharding_constraint(
+        send_sizes, NamedSharding(mesh, P("model")))
+    output_offsets = jax.lax.with_sharding_constraint(
+        output_offsets, NamedSharding(mesh, P("model")))
+
     def _ragged_all_to_all(operand, input_offsets, send_sizes, output_offsets,
                            recv_sizes):
         output = jnp.zeros_like(operand)
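
The send sizes for the ragged all-to-all are now built with an explicit per-shard sum over `ep_size` slices of `group_sizes` and then constrained to the `"model"` axis. The comments keep referring to a working example with `send_sizes == [3, 2, 5, 4]`; reproducing that arithmetic with an assumed `group_sizes` (any eight per-expert counts whose pairwise sums are 3, 2, 5 and 4 behave the same):

```python
import jax.numpy as jnp

ep_size = 4
num_experts_per_shard = 2
# Assumed token counts per expert (8 experts, 2 experts per shard).
group_sizes = jnp.array([1, 2, 0, 2, 4, 1, 3, 1])

send_sizes = jnp.array([
    group_sizes[i * num_experts_per_shard:(i + 1) *
                num_experts_per_shard].sum() for i in range(ep_size)
])
input_offsets = jnp.concatenate((jnp.array([0]), send_sizes.cumsum()[:-1]))

print(send_sizes)     # [3 2 5 4]
print(input_offsets)  # [0 3 5 10]
```
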
@@ -326,20 +316,10 @@ def expert_sharded_gmm(
         mesh=mesh,
         in_specs=(P("model", None), P("model"), P("model"), P("model"), P()),
         out_specs=(P()),
-
+        check_rep=False,
     )(gmm_res, input_offsets, send_sizes, output_offsets, recv_sizes)
 
 
-@functools.partial(
-    jax.jit,
-    static_argnames=(
-        "topk",
-        "renormalize",
-        "mesh",
-        "use_ep",
-        "activation",
-    ),
-)
 def fused_moe_func(
     hidden_states: jax.Array,
     w1: jax.Array,
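
The `@functools.partial(jax.jit, static_argnames=...)` decorator disappears from `fused_moe_func` here; the same decoration shows up again later in this diff on the new `fused_moe_func_padded` wrapper, so only the outer entry point is jitted and its padding branch is resolved while tracing. A minimal sketch of that decoration pattern, with an invented function and arguments:

```python
import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnames=("scale", "use_bias"))
def scaled(x, scale: float, use_bias: bool):
    # `scale` and `use_bias` are hashable Python values baked into the trace,
    # so the `if` below is decided at compile time rather than staged out.
    y = x * scale
    if use_bias:
        y = y + 1.0
    return y

x = jnp.arange(4.0)
print(scaled(x, scale=2.0, use_bias=True))   # [1. 3. 5. 7.]
```
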
@@ -348,45 +328,37 @@ def fused_moe_func(
     w2_bias: jax.Array | None,
     gating_output: jax.Array,
     topk: int,
+    global_num_experts: int,
     renormalize: bool,
+    reduce_results: bool,
     mesh: Mesh,
     use_ep: bool,
     activation: str,
-)
+):
     """
-    Route tokens in hidden_states into each experts based on routing
-    information in gating_out and performs moe with w1 and w2 weights.
-
     Args:
-        hidden_states: [
-        w1:
-        w2:
-
-        w2_bias: optional bias of w2 [num_experts, hidden_size]
-        gating_output: routing information of tokens [num_tokens, num_experts]
-        topk: number of experts to choose per token.
-        renormalize: normalize gating_output.
-        mesh: mesh to perform moe.
-        use_ep: use expert parallelism.
-        activation: activation function to perform on the output of w1.
-
-    Returns:
-        Output of moe operation [num_tokens, hidden_size]
+        hidden_states: [*, hidden_size]
+        w1: [num_experts, intermediate_size * 2, hidden_size]
+        w2: [num_experts, hidden_size, intermediate_size]
+        gating_output: [*, num_experts]
     """
+    # adapted from https://github.com/vllm-project/vllm/blob/29fa5cac1cd731026f59084d93a822921507573c/vllm/model_executor/layers/fused_moe/moe_pallas.py#L26
     if use_ep and (w1_bias is not None or w2_bias is not None):
         raise NotImplementedError(
             "Bias is not supported when using expert parallelism.")
-
-
-
+    orig_shape = hidden_states.shape
+    hidden_size = hidden_states.shape[-1]
+    num_tokens = hidden_states.size // hidden_size
+    assert global_num_experts == w1.shape[0]
+    ep_size = mesh.shape["model"]  # only used if use_ep is True.
+    intermediate_size = w2.shape[-1]
     dtype = hidden_states.dtype
-
     assert (num_tokens * topk) % 16 == 0, (
         "The kernel requires num_tokens * topk to be a multiple of "
         f"16 but got {num_tokens}*{topk}={num_tokens*topk}")
-
-
-
+
+    hidden_states = hidden_states.reshape(num_tokens, hidden_size)
+    gating_output = gating_output.reshape(num_tokens, global_num_experts)
 
     topk_weights = jax.nn.softmax(gating_output.astype(jnp.float32), axis=-1)
     topk_weights, topk_indices = jax.lax.top_k(topk_weights, k=topk)
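
The routing math at the end of this hunk is unchanged in spirit: softmax over the gating logits, keep the top-k experts per token, and (when `renormalize` is set) rescale the kept weights to sum to 1. A tiny numeric sketch of those steps with made-up logits:

```python
import jax
import jax.numpy as jnp

topk = 2
gating_output = jnp.array([[2.0, 0.0, 1.0, -1.0]])   # one token, four experts

topk_weights = jax.nn.softmax(gating_output.astype(jnp.float32), axis=-1)
topk_weights, topk_indices = jax.lax.top_k(topk_weights, k=topk)

# renormalize=True branch: the kept probabilities are rescaled to sum to 1.
topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)

print(topk_indices)   # [[0 2]]
print(topk_weights)   # roughly [[0.73 0.27]]
```
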
@@ -394,76 +366,142 @@ def fused_moe_func(
     topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)
     topk_weights = topk_weights.astype(dtype)
 
-
-
-
-
-
-
-
-
-        group_sizes_local = jnp.bincount(topk_indices_flat,
-                                         length=global_num_experts)
-
-        x = hidden_states_local[token_indices_sorted]
-        return x, group_sizes_local, topk_argsort_revert_indices
-
-    x, group_sizes, topk_argsort_revert_indices = shard_map(
-        _process_tokens_locally,
-        mesh=mesh,
-        in_specs=(P("data", None), P("data", None)),
-        out_specs=(P("data", None), P("data"), P("data")),
-    )(hidden_states, topk_indices)
+    topk_indices_flat = topk_indices.flatten()
+    topk_argsort_indices = jnp.argsort(topk_indices_flat)
+    topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
+    token_indices = jnp.arange(num_tokens, dtype=jnp.int32).repeat(topk)
+    token_indices_sorted = token_indices[topk_argsort_indices]
+    group_sizes = jnp.bincount(topk_indices_flat, length=global_num_experts)
+
+    x = hidden_states[token_indices_sorted]
 
     if use_ep:
         x = expert_sharded_gmm(
             x,
             w1,
             group_sizes,
+            transpose_rhs=True,
             mesh=mesh,
+            num_experts=global_num_experts,
+            ep_size=ep_size,
         )
-        x1, x2 =
-
-        x = activation_fn(activation, x1, x2)
-
-        x = expert_sharded_gmm(
-            x,
-            w2,
-            group_sizes,
-            mesh=mesh,
-        )
+        x1, x2 = x[..., :intermediate_size], x[..., intermediate_size:]
     else:
         x1, x2 = tensor_sharded_gmm_merged_column_parallel(
             x,
             w1,
             w1_bias,
             group_sizes,
+            transpose_rhs=True,
             mesh=mesh,
+            intermediate_size=intermediate_size,
         )
 
-
+    x = activation_fn(activation, x1, x2)
 
+    if use_ep:
+        x = expert_sharded_gmm(
+            x,
+            w2,
+            group_sizes,
+            transpose_rhs=True,
+            mesh=mesh,
+            num_experts=global_num_experts,
+            ep_size=ep_size,
+        )
+    else:
+        x = jax.lax.with_sharding_constraint(
+            x, NamedSharding(mesh, P(None, "model")))
         x = tensor_sharded_gmm_row_parallel(
             x,
             w2,
             w2_bias,
             group_sizes,
+            transpose_rhs=True,
             mesh=mesh,
         )
 
-
-
-
-
-        x_local = x_local * jnp.expand_dims(topk_weights_local, axis=-1)
-        x_local = x_local.sum(axis=-2)
-        return x_local
+    x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
+    x = x * jnp.expand_dims(topk_weights, axis=-1)
+    x = x.sum(axis=-2)
+    x = x.reshape(orig_shape)
 
-
-
-
-        in_specs=(P("data", None), P("data"), P("data", None)),
-        out_specs=(P("data", None)),
-    )(x, topk_argsort_revert_indices, topk_weights)
+    if reduce_results:
+        x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
+    return x
 
-
+
+@functools.partial(
+    jax.jit,
+    static_argnames=(
+        "topk",
+        "global_num_experts",
+        "renormalize",
+        "reduce_results",
+        "mesh",
+        "use_ep",
+        "activation",
+    ),
+)
+def fused_moe_func_padded(
+    hidden_states: jax.Array,
+    w1: jax.Array,
+    w2: jax.Array,
+    w1_bias: jax.Array | None,
+    w2_bias: jax.Array | None,
+    gating_output: jax.Array,
+    topk: int,
+    global_num_experts: int,
+    renormalize: bool,
+    reduce_results: bool,
+    mesh: Mesh,
+    use_ep: bool,
+    activation: str,
+):
+    # TODO(fanhongmin@google.com): Once the jax runner pads the input, we no longer need this.
+    hidden_size = hidden_states.shape[-1]
+    num_tokens = hidden_states.size // hidden_size
+    if num_tokens * topk < 16:
+        assert 16 % (num_tokens *
+                     topk) == 0, f"Cannot pad to 16: {num_tokens=}, {topk=}"
+        n_repeats = 16 // (num_tokens * topk)
+
+        reps = (n_repeats, ) + (1, ) * (hidden_states.ndim - 1)
+        expanded_hidden_states = jnp.tile(hidden_states, reps)
+
+        reps = (n_repeats, ) + (1, ) * (gating_output.ndim - 1)
+        expanded_gating_output = jnp.tile(gating_output, reps)
+
+        expanded_x = fused_moe_func(
+            expanded_hidden_states,
+            w1,
+            w2,
+            w1_bias,
+            w2_bias,
+            expanded_gating_output,
+            topk,
+            global_num_experts,
+            renormalize,
+            reduce_results,
+            mesh,
+            use_ep,
+            activation,
+        )
+        x = expanded_x[:hidden_states.shape[0]]
+        return x
+    else:
+        return fused_moe_func(
+            hidden_states,
+            w1,
+            w2,
+            w1_bias,
+            w2_bias,
+            gating_output,
+            topk,
+            global_num_experts,
+            renormalize,
+            reduce_results,
+            mesh,
+            use_ep,
+            activation,
+        )
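
The new `fused_moe_func_padded` wrapper exists because the kernel asserts that `num_tokens * topk` is a multiple of 16; when the batch is smaller than that, the inputs are tiled up to 16 rows' worth of expert assignments and the extra rows are sliced off afterwards. The arithmetic it relies on, as a standalone sketch with illustrative sizes:

```python
import jax.numpy as jnp

num_tokens, topk, hidden_size = 2, 4, 8
hidden_states = jnp.ones((num_tokens, hidden_size))

# 2 * 4 = 8 < 16, and 16 % 8 == 0, so the batch is tiled twice.
assert 16 % (num_tokens * topk) == 0
n_repeats = 16 // (num_tokens * topk)                     # 2

reps = (n_repeats, ) + (1, ) * (hidden_states.ndim - 1)   # (2, 1)
expanded = jnp.tile(hidden_states, reps)                  # shape (4, 8)

# After the fused MoE runs on the expanded batch, only the original rows are kept.
result = expanded[:hidden_states.shape[0]]
print(expanded.shape, result.shape)                       # (4, 8) (2, 8)
```
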
tpu_inference/layers/vllm/quantization/common.py

@@ -61,12 +61,7 @@ class JaxCommonLinearConfig:
                " bad performance.", type(layer))
 
         self.bias_sharding = P(self.weight_sharding[0])
-
-            self.n_shards = 1
-            for axis in self.weight_sharding[0]:
-                self.n_shards *= self.mesh.shape.get(axis, 1)
-        else:
-            self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
+        self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
 
     def get_input_sharding(self, x: torchax.tensor.Tensor):
         if self.enable_sequence_parallelism: