tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/kernels/mla_v1_test.py +129 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
- tests/lora/test_layers.py +4 -7
- tests/lora/test_lora_perf.py +53 -0
- tests/lora/utils.py +0 -8
- tests/test_envs.py +110 -12
- tests/test_quantization.py +3 -0
- tests/test_utils.py +1 -2
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +3 -4
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +93 -9
- tpu_inference/executors/ray_distributed_executor.py +9 -2
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
- tpu_inference/kernels/mla/v1/kernel.py +98 -120
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +11 -7
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +170 -208
- tpu_inference/layers/vllm/linear_common.py +43 -21
- tpu_inference/layers/vllm/quantization/common.py +11 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
- tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
- tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +84 -28
- tpu_inference/models/jax/deepseek_v3.py +185 -64
- tpu_inference/models/jax/gpt_oss.py +3 -3
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
- tpu_inference/models/jax/utils/weight_utils.py +205 -144
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
- tpu_inference/platforms/tpu_platform.py +34 -50
- tpu_inference/runner/compilation_manager.py +144 -60
- tpu_inference/runner/kv_cache.py +40 -20
- tpu_inference/runner/kv_cache_manager.py +48 -33
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +280 -149
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -21
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +46 -18
- tpu_inference/worker/tpu_worker.py +197 -63
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
@@ -2,17 +2,16 @@ import functools
 
 import jax
 from jax import numpy as jnp
+from jax import shard_map
 from jax.experimental.pallas.ops.tpu.megablox.gmm import gmm
-from jax.
-from jax.sharding import
+from jax.sharding import Mesh
+from jax.sharding import PartitionSpec as P
 
 from tpu_inference.layers.vllm.linear_common import \
     slice_sharded_tensor_for_concatenation
 
-P = PartitionSpec
 
-
-def activation_fn(activation: str, x1, x2):
+def activation_fn(activation: str, x1: jax.Array, x2: jax.Array) -> jax.Array:
     match activation:
         case "silu":
             return jax.nn.silu(x1) * x2
@@ -23,7 +22,10 @@ def activation_fn(activation: str, x1, x2):
                 f"FusedMoE does not support {activation} activation")
 
 
-def _swigluoai(x1
+def _swigluoai(x1: jax.Array,
+               x2: jax.Array,
+               alpha=1.702,
+               limit=7.0) -> jax.Array:
     x1 = jnp.clip(x1, a_max=limit)
     x2 = jnp.clip(x2, a_min=-limit, a_max=limit)
 
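As a quick sanity check on the hunk above (toy values, not package code): the "silu" branch of activation_fn gates the second half of the first grouped matmul's output with the SiLU of the first half, i.e. silu(x1) * x2.

    import jax
    import jax.numpy as jnp

    # Mirrors the diff's silu case: return jax.nn.silu(x1) * x2
    x1 = jnp.array([1.0, -2.0])
    x2 = jnp.array([0.5, 3.0])
    print(jax.nn.silu(x1) * x2)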
@@ -103,40 +105,53 @@ def tensor_sharded_gmm_merged_column_parallel(
     rhs: jax.Array,
     rhs_bias: jax.Array | None,
     group_sizes: jax.Array,
-    transpose_rhs: bool,
     mesh: Mesh,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+) -> tuple[jax.Array, jax.Array]:
+
+    def _gmm(lhs, rhs, group_sizes):
+        m, g, n, k = lhs.shape[0], *rhs.shape
+        tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+        return gmm(
+            lhs,
+            rhs,
+            group_sizes,
+            preferred_element_type=lhs.dtype,
+            tiling=(tm, tk, tn),
+            transpose_rhs=True,
+            group_offset=jnp.array(0),
+        )
 
     gmm_result = shard_map(
         _gmm,
         mesh=mesh,
-        in_specs=(P(), P(None, "model", None), P()),
-        out_specs=(P(
-
+        in_specs=(P("data", None), P(None, "model", None), P("data")),
+        out_specs=(P("data", "model")),
+        check_vma=False,
     )(lhs, rhs, group_sizes)
 
     if rhs_bias is not None:
-        rhs_bis = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
-        gmm_result = (gmm_result + rhs_bis).astype(gmm_result.dtype)
 
-
-
+        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
+            rhs_bias = jnp.repeat(
+                rhs_bias_local,
+                group_sizes_global,
+                0,
+                total_repeat_length=gmm_result_local.shape[0])
+            return gmm_result_local + rhs_bias
+
+        gmm_result = shard_map(
+            _add_bias,
+            mesh=mesh,
+            in_specs=(P("data", "model"), P(None, "model"), P("data")),
+            out_specs=(P("data", "model")),
+        )(gmm_result, rhs_bias, group_sizes)
+        gmm_result = gmm_result.astype(lhs.dtype)
 
+    tp_size = mesh.shape["model"]
+    intermediate_size = gmm_result.shape[-1] // 2
+    output_sizes = [intermediate_size, intermediate_size]
     return slice_sharded_tensor_for_concatenation(gmm_result, output_sizes,
-
+                                                   tp_size)
 
 
 def tensor_sharded_gmm_row_parallel(
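The rewritten tensor_sharded_gmm_merged_column_parallel above wraps the megablox gmm call in shard_map so each device runs the grouped matmul only on its local slice of the expert weights. Below is a minimal sketch of that column-parallel pattern with a plain matmul; the toy shapes, the single-device 1x1 mesh, and the jax.experimental import path are assumptions for illustration, and only the axis names "data" and "model" are taken from the diff.

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, PartitionSpec as P
    from jax.experimental.shard_map import shard_map

    # 1x1 mesh so the sketch runs anywhere; real meshes place devices on both axes.
    mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ("data", "model"))

    def local_matmul(x, w):
        # Each shard holds a slice of the tokens ("data") and of the output
        # features ("model"); no collective is needed because the contraction
        # dimension is replicated, as in the column-parallel case above.
        return x @ w

    y = shard_map(
        local_matmul,
        mesh=mesh,
        in_specs=(P("data", None), P(None, "model")),
        out_specs=P("data", "model"),
    )(jnp.ones((8, 16)), jnp.ones((16, 4)))
    print(y.shape)  # (8, 4)

The row-parallel variant in the next hunk shards the contraction dimension instead, which is why its local gmm result is combined with jax.lax.psum over the "model" axis.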
@@ -144,74 +159,75 @@ def tensor_sharded_gmm_row_parallel(
     rhs: jax.Array,
     rhs_bias: jax.Array | None,
     group_sizes: jax.Array,
-    transpose_rhs: bool,
     mesh: Mesh,
 ) -> jax.Array:
-    # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
-    m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
-    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
-
-    _gmm = functools.partial(
-        gmm,
-        preferred_element_type=lhs.dtype,
-        tiling=(tm, tk, tn),
-        transpose_rhs=transpose_rhs,
-        group_offset=jnp.array(0),
-    )
 
     def _gmm_all_reduce(lhs, rhs, group_sizes):
-
-
+        m, g, n, k = lhs.shape[0], *rhs.shape
+        tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+        out = gmm(
+            lhs,
+            rhs,
+            group_sizes,
+            preferred_element_type=lhs.dtype,
+            tiling=(tm, tk, tn),
+            transpose_rhs=True,
+            group_offset=jnp.array(0),
+        )
+        return jax.lax.psum(out, axis_name="model")
 
     gmm_result = shard_map(
         _gmm_all_reduce,
         mesh=mesh,
-        in_specs=(P(
-        out_specs=(P()),
-
+        in_specs=(P("data", "model"), P(None, None, "model"), P("data")),
+        out_specs=(P("data")),
+        check_vma=False,
     )(lhs, rhs, group_sizes)
 
     if rhs_bias is not None:
-        rhs_bias = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
-        gmm_result = (gmm_result + rhs_bias).astype(gmm_result.dtype)
 
-
+        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
+            rhs_bias = jnp.repeat(
+                rhs_bias_local,
+                group_sizes_global,
+                0,
+                total_repeat_length=gmm_result_local.shape[0])
+            return gmm_result_local + rhs_bias
+
+        gmm_result = shard_map(
+            _add_bias,
+            mesh=mesh,
+            in_specs=(P("data"), P(), P("data")),
+            out_specs=(P("data")),
+        )(gmm_result, rhs_bias, group_sizes)
+
+    return gmm_result.astype(lhs.dtype)
 
 
 def expert_sharded_gmm(
     lhs: jax.Array,
     rhs: jax.Array,
     group_sizes: jax.Array,
-    transpose_rhs: bool,
     mesh: Mesh,
-    num_experts: int,
-    ep_size: int,
 ) -> jax.Array:
-
-    m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
-    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+    ep_size = mesh.shape["model"]
 
+    num_experts = rhs.shape[0]
     num_experts_per_shard = num_experts // ep_size
     group_offset = jnp.arange(0, num_experts, num_experts_per_shard)
-    group_offset = jax.lax.with_sharding_constraint(
-        group_offset, NamedSharding(mesh, P("model")))
 
     def _gmm(lhs, rhs, group_sizes, group_offset):
-
-
-
-        # so we group_offset[0].
-        group_offset_of_shard = group_offset[0]
+        m, g, n, k = lhs.shape[0], *rhs.shape
+        tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+
         gmm_res = gmm(
             lhs=lhs,
             rhs=rhs,
             group_sizes=group_sizes,
             preferred_element_type=lhs.dtype,
             tiling=(tm, tk, tn),
-            transpose_rhs=
-            group_offset=
+            transpose_rhs=True,
+            group_offset=group_offset[0],
         )
         return gmm_res
 
@@ -238,30 +254,24 @@ def expert_sharded_gmm(
         mesh=mesh,
         in_specs=(P(), P("model", None, None), P(), P("model")),
         out_specs=(P("model", None)),
-
+        check_vma=False,
     )(lhs, rhs, group_sizes, group_offset)
 
     # For i-th shard, it is responsible groups (AKA experts) from
     # i*num_experts_per_shard to (i+1)*num_experts_per_shard We sum them up to
     # get total rows in that shard, and that is the size for shard to send to
     # its peers. This is also the number of non-zero rows from the gmm results.
-    # In the working example, send_sizes would be [3, 2, 5, 4]
-
-
-
-    ]
+    # In the working example, send_sizes would be [3, 2, 5, 4].
+
+    # group_sizes has shape of [num_tokens_per_shard * num_experts_per_shard].
+    # So reshaping to [num_tokens_per_shard, num_experts_per_shard] and applying
+    # sum(axis=1) will get desired send_sizes shaped [num_tokens_per_shard].
+    send_sizes = group_sizes.reshape(-1, num_experts_per_shard).sum(axis=1)
     # In the working example, input_offsets would be [0, 3, 5, 10]
     input_offsets = jnp.concatenate((jnp.array([0]), send_sizes.cumsum()[:-1]))
     output_offsets = input_offsets
     recv_sizes = send_sizes
 
-    input_offsets = jax.lax.with_sharding_constraint(
-        input_offsets, NamedSharding(mesh, P("model")))
-    send_sizes = jax.lax.with_sharding_constraint(
-        send_sizes, NamedSharding(mesh, P("model")))
-    output_offsets = jax.lax.with_sharding_constraint(
-        output_offsets, NamedSharding(mesh, P("model")))
-
     def _ragged_all_to_all(operand, input_offsets, send_sizes, output_offsets,
                            recv_sizes):
         output = jnp.zeros_like(operand)
@@ -316,10 +326,20 @@ def expert_sharded_gmm(
         mesh=mesh,
         in_specs=(P("model", None), P("model"), P("model"), P("model"), P()),
         out_specs=(P()),
-
+        check_vma=False,
     )(gmm_res, input_offsets, send_sizes, output_offsets, recv_sizes)
 
 
+@functools.partial(
+    jax.jit,
+    static_argnames=(
+        "topk",
+        "renormalize",
+        "mesh",
+        "use_ep",
+        "activation",
+    ),
+)
 def fused_moe_func(
     hidden_states: jax.Array,
     w1: jax.Array,
@@ -328,37 +348,45 @@ def fused_moe_func(
     w2_bias: jax.Array | None,
     gating_output: jax.Array,
     topk: int,
-    global_num_experts: int,
     renormalize: bool,
-    reduce_results: bool,
     mesh: Mesh,
     use_ep: bool,
     activation: str,
-):
+) -> jax.Array:
     """
+    Route tokens in hidden_states into each experts based on routing
+    information in gating_out and performs moe with w1 and w2 weights.
+
     Args:
-        hidden_states: [
-        w1: [num_experts, intermediate_size * 2, hidden_size]
-        w2: [num_experts, hidden_size, intermediate_size]
-
+        hidden_states: [num_tokens, hidden_size]
+        w1: first moe weights [num_experts, intermediate_size * 2, hidden_size]
+        w2: second moe weights [num_experts, hidden_size, intermediate_size]
+        w1_bias: optional bias of w1 [num_experts, intermediate_size * 2]
+        w2_bias: optional bias of w2 [num_experts, hidden_size]
+        gating_output: routing information of tokens [num_tokens, num_experts]
+        topk: number of experts to choose per token.
+        renormalize: normalize gating_output.
+        mesh: mesh to perform moe.
+        use_ep: use expert parallelism.
+        activation: activation function to perform on the output of w1.
+
+    Returns:
+        Output of moe operation [num_tokens, hidden_size]
     """
-    # adapted from https://github.com/vllm-project/vllm/blob/29fa5cac1cd731026f59084d93a822921507573c/vllm/model_executor/layers/fused_moe/moe_pallas.py#L26
     if use_ep and (w1_bias is not None or w2_bias is not None):
         raise NotImplementedError(
             "Bias is not supported when using expert parallelism.")
-
-
-
-    assert global_num_experts == w1.shape[0]
-    ep_size = mesh.shape["model"]  # only used if use_ep is True.
-    intermediate_size = w2.shape[-1]
+
+    num_tokens = hidden_states.shape[0]
+    global_num_experts, hidden_size, intermediate_size = w2.shape
     dtype = hidden_states.dtype
+
     assert (num_tokens * topk) % 16 == 0, (
         "The kernel requires num_tokens * topk to be a multiple of "
         f"16 but got {num_tokens}*{topk}={num_tokens*topk}")
-
-
-
+    assert hidden_states.shape == (num_tokens, hidden_size)
+    assert gating_output.shape == (num_tokens, global_num_experts)
+    assert w1.shape == (global_num_experts, intermediate_size * 2, hidden_size)
 
     topk_weights = jax.nn.softmax(gating_output.astype(jnp.float32), axis=-1)
     topk_weights, topk_indices = jax.lax.top_k(topk_weights, k=topk)
@@ -366,142 +394,76 @@ def fused_moe_func(
         topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)
     topk_weights = topk_weights.astype(dtype)
 
-
-
-
-
-
-
-
-
+    def _process_tokens_locally(hidden_states_local, topk_indices_local):
+        num_tokens_local = hidden_states_local.shape[0]
+        topk_indices_flat = topk_indices_local.flatten()
+        topk_argsort_indices = jnp.argsort(topk_indices_flat)
+        topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
+        token_indices = jnp.arange(num_tokens_local,
+                                   dtype=jnp.int32).repeat(topk)
+        token_indices_sorted = token_indices[topk_argsort_indices]
+        group_sizes_local = jnp.bincount(topk_indices_flat,
+                                         length=global_num_experts)
+
+        x = hidden_states_local[token_indices_sorted]
+        return x, group_sizes_local, topk_argsort_revert_indices
+
+    x, group_sizes, topk_argsort_revert_indices = shard_map(
+        _process_tokens_locally,
+        mesh=mesh,
+        in_specs=(P("data", None), P("data", None)),
+        out_specs=(P("data", None), P("data"), P("data")),
+    )(hidden_states, topk_indices)
 
     if use_ep:
         x = expert_sharded_gmm(
             x,
             w1,
             group_sizes,
-            transpose_rhs=True,
             mesh=mesh,
-            num_experts=global_num_experts,
-            ep_size=ep_size,
         )
-        x1, x2 = x
+        x1, x2 = jnp.split(x, 2, -1)
+
+        x = activation_fn(activation, x1, x2)
+
+        x = expert_sharded_gmm(
+            x,
+            w2,
+            group_sizes,
+            mesh=mesh,
+        )
     else:
         x1, x2 = tensor_sharded_gmm_merged_column_parallel(
             x,
             w1,
             w1_bias,
             group_sizes,
-            transpose_rhs=True,
             mesh=mesh,
-            intermediate_size=intermediate_size,
         )
 
-
+        x = activation_fn(activation, x1, x2)
 
-    if use_ep:
-        x = expert_sharded_gmm(
-            x,
-            w2,
-            group_sizes,
-            transpose_rhs=True,
-            mesh=mesh,
-            num_experts=global_num_experts,
-            ep_size=ep_size,
-        )
-    else:
-        x = jax.lax.with_sharding_constraint(
-            x, NamedSharding(mesh, P(None, "model")))
         x = tensor_sharded_gmm_row_parallel(
             x,
             w2,
             w2_bias,
             group_sizes,
-            transpose_rhs=True,
             mesh=mesh,
         )
 
-
-
-
-
-
-
-
-    return x
+    def _finalize_output(x_local, topk_argsort_revert_indices_local,
+                         topk_weights_local):
+        x_local = x_local[topk_argsort_revert_indices_local].reshape(
+            -1, topk, hidden_size)
+        x_local = x_local * jnp.expand_dims(topk_weights_local, axis=-1)
+        x_local = x_local.sum(axis=-2)
+        return x_local
 
+    x = shard_map(
+        _finalize_output,
+        mesh=mesh,
+        in_specs=(P("data", None), P("data"), P("data", None)),
+        out_specs=(P("data", None)),
+    )(x, topk_argsort_revert_indices, topk_weights)
 
-
-    jax.jit,
-    static_argnames=(
-        "topk",
-        "global_num_experts",
-        "renormalize",
-        "reduce_results",
-        "mesh",
-        "use_ep",
-        "activation",
-    ),
-)
-def fused_moe_func_padded(
-    hidden_states: jax.Array,
-    w1: jax.Array,
-    w2: jax.Array,
-    w1_bias: jax.Array | None,
-    w2_bias: jax.Array | None,
-    gating_output: jax.Array,
-    topk: int,
-    global_num_experts: int,
-    renormalize: bool,
-    reduce_results: bool,
-    mesh: Mesh,
-    use_ep: bool,
-    activation: str,
-):
-    # TODO(fanhongmin@google.com): Once the jax runner pads the input, we no longer need this.
-    hidden_size = hidden_states.shape[-1]
-    num_tokens = hidden_states.size // hidden_size
-    if num_tokens * topk < 16:
-        assert 16 % (num_tokens *
-                     topk) == 0, f"Cannot pad to 16: {num_tokens=}, {topk=}"
-        n_repeats = 16 // (num_tokens * topk)
-
-        reps = (n_repeats, ) + (1, ) * (hidden_states.ndim - 1)
-        expanded_hidden_states = jnp.tile(hidden_states, reps)
-
-        reps = (n_repeats, ) + (1, ) * (gating_output.ndim - 1)
-        expanded_gating_output = jnp.tile(gating_output, reps)
-
-        expanded_x = fused_moe_func(
-            expanded_hidden_states,
-            w1,
-            w2,
-            w1_bias,
-            w2_bias,
-            expanded_gating_output,
-            topk,
-            global_num_experts,
-            renormalize,
-            reduce_results,
-            mesh,
-            use_ep,
-            activation,
-        )
-        x = expanded_x[:hidden_states.shape[0]]
-        return x
-    else:
-        return fused_moe_func(
-            hidden_states,
-            w1,
-            w2,
-            w1_bias,
-            w2_bias,
-            gating_output,
-            topk,
-            global_num_experts,
-            renormalize,
-            reduce_results,
-            mesh,
-            use_ep,
-            activation,
-        )
+    return x[:num_tokens, :hidden_size]
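Taken together, the new fused_moe_func routes each token to its top-k experts, runs the two grouped matmuls with the chosen activation in between, and sums the expert outputs weighted by the (optionally renormalized) gate probabilities. The sketch below is a plain-jnp reference of that routing/combine math, dense over experts and hard-coding the silu gate; it is a reading aid for the diff, not the package's TPU kernel, and the shapes and names are made up.

    import jax
    import jax.numpy as jnp

    def moe_reference(hidden_states, w1, w2, gating_output, topk, renormalize):
        num_tokens = hidden_states.shape[0]
        weights = jax.nn.softmax(gating_output.astype(jnp.float32), axis=-1)
        topk_weights, topk_idx = jax.lax.top_k(weights, k=topk)
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)

        # Dense per-expert compute; the kernel instead sorts tokens by expert id
        # and runs grouped matmuls (gmm) over contiguous token groups.
        def run_expert(e):
            x1, x2 = jnp.split(hidden_states @ w1[e].T, 2, axis=-1)
            return (jax.nn.silu(x1) * x2) @ w2[e].T

        expert_out = jax.vmap(run_expert)(jnp.arange(w1.shape[0]))        # [E, T, H]
        gathered = expert_out[topk_idx, jnp.arange(num_tokens)[:, None]]  # [T, k, H]
        out = (gathered * topk_weights[..., None]).sum(axis=1)
        return out.astype(hidden_states.dtype)

    tokens = jnp.ones((4, 8), jnp.float32)           # [num_tokens, hidden_size]
    w1 = jnp.ones((2, 12, 8), jnp.float32)           # [num_experts, intermediate * 2, hidden]
    w2 = jnp.ones((2, 8, 6), jnp.float32)            # [num_experts, hidden, intermediate]
    gates = jnp.tile(jnp.array([0.1, 0.9]), (4, 1))  # [num_tokens, num_experts]
    print(moe_reference(tokens, w1, w2, gates, topk=1, renormalize=True).shape)  # (4, 8)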
@@ -9,30 +9,52 @@ from jax.sharding import PartitionSpec as P
 from torchax.interop import torch_view
 from torchax.ops.mappings import t2j
 
-from tpu_inference
-
+from tpu_inference import envs
+from tpu_inference.kernels.quantized_matmul.kernel import (
+    quantized_matmul_kernel, xla_quantized_matmul)
 
 
 def sharded_quantized_matmul(x: jax.Array, w_q: jax.Array, w_s: jax.Array,
-                             mesh: Mesh, weight_sharding: P):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                             mesh: Mesh, weight_sharding: P) -> jax.Array:
+    """
+    Wrapper around the quantized matmul kernel.
+
+    Args:
+        x: Activation.
+        w_q: Weight quantized array. [n_output_features, n_input_features]
+        w_s: Weight quantization scale. [n_output_features]
+        mesh: Mesh to shard on.
+        weight_sharding: PartitionSpec for the weight tensor.
+
+    Returns:
+        Output of the quantized matmul.
+    """
+
+    # NOTE (jacobplatin/kyuyeunk) there have been numeric issues (concerning) NaNs
+    # with the kernel and thus we disable it for now.
+    if envs.ENABLE_QUANTIZED_MATMUL_KERNEL:
+        out_axis, in_axis = weight_sharding
+        x_sharding = P(None, in_axis)
+        scale_sharding = P(out_axis, )
+        out_sharding = P(None, out_axis)
+
+        x = jax.lax.with_sharding_constraint(x,
+                                             NamedSharding(mesh, x_sharding))
+
+        def wrapper(x, w_q, w_s):
+            output = quantized_matmul_kernel(x, w_q, w_s, x_q_dtype=w_q.dtype)
+            if in_axis:
+                output = jax.lax.psum(output, axis_name=in_axis)
+            return output
+
+        return shard_map(wrapper,
+                         mesh=mesh,
+                         in_specs=(x_sharding, weight_sharding,
+                                   scale_sharding),
+                         out_specs=(out_sharding),
+                         check_rep=False)(x, w_q, w_s)
+    else:
+        return xla_quantized_matmul(x, w_q, w_s)
 
 
 def reorder_concatenated_tensor_for_sharding(concatenated_tensor: jax.Array,
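For orientation, the shape contract in the new docstring (w_q: [n_output_features, n_input_features], w_s: [n_output_features]) describes a matmul against quantized weights with one scale per output feature. The toy reference below shows only that dequantize-and-multiply math; it is not the package's Pallas kernel nor its xla_quantized_matmul fallback (neither implementation appears in this diff), and the shapes are invented.

    import jax.numpy as jnp

    def reference_quantized_matmul(x, w_q, w_s):
        # Dequantize on the fly: int8 weights scaled per output feature.
        return (x @ w_q.T.astype(x.dtype)) * w_s.astype(x.dtype)[None, :]

    x = jnp.ones((2, 4), jnp.bfloat16)       # [tokens, n_input_features]
    w_q = jnp.ones((8, 4), jnp.int8)         # [n_output_features, n_input_features]
    w_s = jnp.full((8,), 0.5, jnp.bfloat16)  # [n_output_features]
    print(reference_quantized_matmul(x, w_q, w_s).shape)  # (2, 8)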
@@ -31,17 +31,17 @@ class JaxCommonLinearConfig:
         self.output_sizes = [layer.output_size]
         self.weight_sharding = P(None, None)
         self.fuse_matmuls = True
-        self.
+        self.enable_sp = vllm_config.compilation_config.pass_config.enable_sp
         self.input_sharding = None
         self.output_sharding = None
 
         if isinstance(layer, RowParallelLinear):
             self.weight_sharding = P(None, "model")
-            if self.
+            if self.enable_sp:
                 self.output_sharding = P("model", None)
         elif isinstance(layer, ColumnParallelLinear):
             self.weight_sharding = P("model", None)
-            if self.
+            if self.enable_sp:
                 self.input_sharding = P("model", None)
 
         if isinstance(layer, MergedColumnParallelLinear) or isinstance(
@@ -61,10 +61,15 @@ class JaxCommonLinearConfig:
             " bad performance.", type(layer))
 
         self.bias_sharding = P(self.weight_sharding[0])
-
+        if isinstance(self.weight_sharding[0], tuple):
+            self.n_shards = 1
+            for axis in self.weight_sharding[0]:
+                self.n_shards *= self.mesh.shape.get(axis, 1)
+        else:
+            self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
 
     def get_input_sharding(self, x: torchax.tensor.Tensor):
-        if self.
+        if self.enable_sp:
             token_num = x.shape[0]
             # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
             if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
@@ -74,7 +79,7 @@ class JaxCommonLinearConfig:
         return self.input_sharding
 
     def get_output_sharding(self, x: torchax.tensor.Tensor):
-        if self.
+        if self.enable_sp:
             token_num = x.shape[0]
             # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
             if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
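The n_shards logic added above handles PartitionSpec entries that name a tuple of mesh axes as well as a single axis: the shard count is the product of the named axes' sizes, defaulting to 1 for None or for axes missing from the mesh. A standalone sketch of that calculation, with a plain dict standing in for Mesh.shape (the 2x4 mesh is made up):

    mesh_shape = {"data": 2, "model": 4}  # stand-in for Mesh.shape

    def n_shards(spec_entry):
        if isinstance(spec_entry, tuple):
            count = 1
            for axis in spec_entry:
                count *= mesh_shape.get(axis, 1)
            return count
        return mesh_shape.get(spec_entry, 1)

    print(n_shards("model"))            # 4
    print(n_shards(("data", "model")))  # 8
    print(n_shards(None))               # 1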
@@ -20,7 +20,7 @@ from tpu_inference.layers.common.quant_methods import (COMPRESSED_TENSORS,
                                                         get_tpu_quant_method)
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
-
+    VllmCompressedTensorsMoEMethod
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
     VllmCompressedTensorsW8A8Fp8
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
@@ -113,8 +113,9 @@ class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
             layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
         if isinstance(layer, FusedMoE):
-
-
+            layer.moe_config = self.get_moe_config(layer)
+            return VllmCompressedTensorsMoEMethod.get_moe_method(
+                self, layer, layer_name=prefix)
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
         return None