tpu-inference 0.11.1rc2__py3-none-any.whl → 0.11.1rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tpu_inference/kernels/collectives/__init__.py +0 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
- tpu_inference/kernels/collectives/util.py +47 -0
- tpu_inference/layers/__init__.py +0 -0
- tpu_inference/layers/common/__init__.py +0 -0
- tpu_inference/layers/common/attention_metadata.py +34 -0
- tpu_inference/layers/jax/__init__.py +0 -0
- tpu_inference/layers/jax/attention/__init__.py +0 -0
- tpu_inference/layers/jax/attention/attention.py +254 -0
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
- tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
- tpu_inference/layers/jax/attention_interface.py +356 -0
- tpu_inference/layers/jax/base.py +151 -0
- tpu_inference/layers/jax/binary_search.py +295 -0
- tpu_inference/layers/jax/constants.py +88 -0
- tpu_inference/layers/jax/layers.py +301 -0
- tpu_inference/layers/jax/misc.py +16 -0
- tpu_inference/layers/jax/moe/__init__.py +0 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
- tpu_inference/layers/jax/moe/moe.py +209 -0
- tpu_inference/layers/jax/rope.py +172 -0
- tpu_inference/layers/jax/rope_interface.py +214 -0
- tpu_inference/layers/jax/sample/__init__.py +0 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
- tpu_inference/layers/jax/sample/sampling.py +95 -0
- tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
- tpu_inference/layers/jax/sharding.py +406 -0
- tpu_inference/layers/jax/transformer_block.py +76 -0
- tpu_inference/layers/vllm/__init__.py +0 -0
- tpu_inference/layers/vllm/attention.py +184 -0
- tpu_inference/layers/vllm/fused_moe.py +399 -0
- tpu_inference/layers/vllm/linear_common.py +186 -0
- tpu_inference/layers/vllm/quantization/__init__.py +34 -0
- tpu_inference/layers/vllm/quantization/awq.py +207 -0
- tpu_inference/layers/vllm/quantization/common.py +105 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
- tpu_inference/layers/vllm/sharding.py +151 -0
- tpu_inference/models/common/__init__.py +0 -0
- tpu_inference/models/common/model_loader.py +433 -0
- {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/METADATA +6 -6
- {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/RECORD +50 -5
- {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/WHEEL +1 -1
- {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1rc2.dist-info → tpu_inference-0.11.1rc3.dist-info}/top_level.txt +0 -0

@@ -0,0 +1,399 @@
+import functools
+
+import jax
+from jax import numpy as jnp
+from jax.experimental.pallas.ops.tpu.megablox.gmm import gmm
+from jax.experimental.shard_map import shard_map
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
+
+from tpu_inference.layers.vllm.linear_common import \
+    slice_sharded_tensor_for_concatenation
+
+P = PartitionSpec
+
+
+def _round_up_to_multiple_of_128_within_limit(x: int, limit: int) -> int:
+    """
+    Rounds the given integer `x` up to the nearest multiple of 128, without exceeding
+    the specified `limit`.
+
+    If `x` is less than or equal to 128, returns 128.
+    If `x` is less than `limit`, returns the smallest multiple of 128 greater than or
+    equal to `x`.
+    If `x` is greater than or equal to `limit`, searches for the largest multiple of
+    128 less than or equal to `limit` (down to 512) that divides `x` evenly, and
+    returns it.
+    If no such candidate is found, returns `limit`.
+
+    Args:
+        x (int): The integer to round up.
+        limit (int): The upper bound (must be a multiple of 128 and at least 128).
+
+    Returns:
+        int: The rounded value according to the rules above.
+
+    Raises:
+        AssertionError: If `limit` is less than 128 or not a multiple of 128.
+    """
+    assert limit >= 128 and limit % 128 == 0
+    if x <= 128:
+        return 128
+    if x < limit:
+        return (x + 127) // 128 * 128
+    for candidate in range(limit, 511, -128):
+        if x % candidate == 0:
+            return candidate
+    return limit
+
+
+def _get_tiling_size_for_gmm_kernel(m: int, k: int, n: int,
+                                    g: int) -> tuple[int, int, int]:
+    """
+    Calculate optimal tiling sizes for a GMM kernel in a Mixture of Experts
+    (MoE) setting.
+
+    Args:
+        m (int): The total number of tokens.
+        k (int): The input feature dimension.
+        n (int): The output feature dimension.
+        g (int): The number of experts.
+
+    Returns:
+        tuple[int, int, int]: A tuple (tm, tk, tn)
+    """
+
+    # TODO(Chengji): increase the upper limit tiling size of m when we can set
+    # the vmem size to be used for gmm kernel.
+    # NOTE: On average each expert has m // g tokens, but as the load might be
+    # unbalanced, we double the token count when choosing the tiling size of m.
+    # 2*m//g can be either greater or less than 512. If there are 32 tokens and
+    # topk=2, m = topk * num_tokens = 64, in which case 2*m//g is less than 512.
+    tm = _round_up_to_multiple_of_128_within_limit(2 * m // g, 512)
+    tm = min(tm, m)  # there's a requirement that m % tm == 0
+    # k/n correspond to n_input_features/n_output_features in the matmul, so they
+    # are normally greater than 2048 unless the number of shards is large.
+    tk = _round_up_to_multiple_of_128_within_limit(k, 2048)
+    tn = _round_up_to_multiple_of_128_within_limit(n, 2048)
+    return tm, tk, tn
+
+
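For orientation, this first hunk appears to be the new tpu_inference/layers/vllm/fused_moe.py (the 399-line addition in the listing above). As a quick, hedged illustration of the rounding rule documented in the docstring, here is a standalone re-statement of the helper with a few worked values; the function is re-declared locally for the sketch rather than imported from the wheel.

    def round_up_to_multiple_of_128_within_limit(x: int, limit: int) -> int:
        # Same rule as the helper above, restated for illustration only.
        assert limit >= 128 and limit % 128 == 0
        if x <= 128:
            return 128
        if x < limit:
            return (x + 127) // 128 * 128
        for candidate in range(limit, 511, -128):
            if x % candidate == 0:
                return candidate
        return limit

    print(round_up_to_multiple_of_128_within_limit(100, 512))    # 128: small inputs clamp to 128
    print(round_up_to_multiple_of_128_within_limit(300, 512))    # 384: next multiple of 128
    print(round_up_to_multiple_of_128_within_limit(2048, 512))   # 512: 512 divides 2048 evenly
    print(round_up_to_multiple_of_128_within_limit(2176, 2048))  # 2048: no candidate down to 512 divides 2176, so fall back to limit

With g experts and m routed rows, the tiling helper above then picks tm = min(round_up(2*m//g, 512), m) and rounds k and n against a 2048 limit.
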
+def tensor_sharded_gmm_merged_column_parallel(
+        lhs: jax.Array, rhs: jax.Array, group_sizes: jax.Array,
+        transpose_rhs: bool, mesh: Mesh, intermediate_size: int) -> jax.Array:
+    # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
+    m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
+    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+
+    _gmm = functools.partial(
+        gmm,
+        preferred_element_type=lhs.dtype,
+        tiling=(tm, tk, tn),
+        transpose_rhs=transpose_rhs,
+        group_offset=jnp.array(0),
+    )
+
+    gmm_result = shard_map(
+        _gmm,
+        mesh=mesh,
+        in_specs=(P(), P(None, "model", None), P()),
+        out_specs=(P(None, "model")),
+        check_rep=False,
+    )(lhs, rhs, group_sizes)
+
+    n_shards = mesh.shape["model"]
+    output_sizes = [intermediate_size, intermediate_size]
+
+    return slice_sharded_tensor_for_concatenation(gmm_result, output_sizes,
+                                                  n_shards)
+
+
+def tensor_sharded_gmm_row_parallel(
+    lhs: jax.Array,
+    rhs: jax.Array,
+    group_sizes: jax.Array,
+    transpose_rhs: bool,
+    mesh: Mesh,
+) -> jax.Array:
+    # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
+    m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
+    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+
+    _gmm = functools.partial(
+        gmm,
+        preferred_element_type=lhs.dtype,
+        tiling=(tm, tk, tn),
+        transpose_rhs=transpose_rhs,
+        group_offset=jnp.array(0),
+    )
+
+    def _gmm_all_reduce(lhs, rhs, group_sizes):
+        r = _gmm(lhs, rhs, group_sizes)
+        return jax.lax.psum(r, axis_name="model")
+
+    return shard_map(
+        _gmm_all_reduce,
+        mesh=mesh,
+        in_specs=(P(None, "model"), P(None, None, "model"), P()),
+        out_specs=(P()),
+        check_rep=False,
+    )(lhs, rhs, group_sizes)
+
+
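The two wrappers above follow the usual column-parallel / row-parallel tensor-parallel split: the merged column-parallel version shards the output dimension and then slices the fused result, while the row-parallel version shards the contraction dimension and all-reduces partial products with psum. A minimal, hedged sketch of the row-parallel half, using a plain matmul instead of the Pallas gmm kernel so it runs on any backend (all names here are local to the sketch; a single-device "model" mesh is assumed, so the psum is a no-op):

    import numpy as np

    import jax
    import jax.numpy as jnp
    from jax.experimental.shard_map import shard_map
    from jax.sharding import Mesh, PartitionSpec as P

    # One-device mesh named "model", mirroring the axis name used above.
    mesh = Mesh(np.array(jax.devices()[:1]), ("model", ))

    def _matmul_all_reduce(x, w):
        # Each shard holds a slice of the contraction (K) dimension; the partial
        # products are summed across the "model" axis, just like the psum inside
        # tensor_sharded_gmm_row_parallel.
        return jax.lax.psum(x @ w, axis_name="model")

    x = jnp.ones((8, 16))
    w = jnp.ones((16, 4))
    y = shard_map(_matmul_all_reduce,
                  mesh=mesh,
                  in_specs=(P(None, "model"), P("model", None)),
                  out_specs=P(),
                  check_rep=False)(x, w)
    print(y.shape)  # (8, 4), replicated across the mesh
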
+def expert_sharded_gmm(
+    lhs: jax.Array,
+    rhs: jax.Array,
+    group_sizes: jax.Array,
+    transpose_rhs: bool,
+    mesh: Mesh,
+    num_experts: int,
+    ep_size: int,
+) -> jax.Array:
+    # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
+    m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
+    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+
+    num_experts_per_shard = num_experts // ep_size
+    group_offset = jnp.arange(0, num_experts, num_experts_per_shard)
+    group_offset = jax.lax.with_sharding_constraint(
+        group_offset, NamedSharding(mesh, P("model")))
+
+    def _gmm(lhs, rhs, group_sizes, group_offset):
+        # Group offset for this shard. `group_offset` is sharded, and inside this
+        # sharded function it has only 1 element (`group_offset.shape` is (1,)),
+        # but the gmm kernel requires a ()-shaped array, so we take group_offset[0].
+        group_offset_of_shard = group_offset[0]
+        return gmm(lhs=lhs,
+                   rhs=rhs,
+                   group_sizes=group_sizes,
+                   preferred_element_type=lhs.dtype,
+                   tiling=(tm, tk, tn),
+                   transpose_rhs=transpose_rhs,
+                   group_offset=group_offset_of_shard)
+
+    # The result from gmm on each shard has the same shape, but only the rows for this shard have non-zero values. Take the below as a working example:
+    # A, A, A, A    0, 0, 0, 0    0, 0, 0, 0    0, 0, 0, 0
+    # A, A, A, A    0, 0, 0, 0    0, 0, 0, 0    0, 0, 0, 0
+    # A, A, A, A    0, 0, 0, 0    0, 0, 0, 0    0, 0, 0, 0
+    # 0, 0, 0, 0    B, B, B, B    0, 0, 0, 0    0, 0, 0, 0
+    # 0, 0, 0, 0    B, B, B, B    0, 0, 0, 0    0, 0, 0, 0
+    # 0, 0, 0, 0    0, 0, 0, 0    C, C, C, C    0, 0, 0, 0
+    # 0, 0, 0, 0    0, 0, 0, 0    C, C, C, C    0, 0, 0, 0
+    # 0, 0, 0, 0    0, 0, 0, 0    C, C, C, C    0, 0, 0, 0
+    # 0, 0, 0, 0    0, 0, 0, 0    C, C, C, C    0, 0, 0, 0
+    # 0, 0, 0, 0    0, 0, 0, 0    C, C, C, C    0, 0, 0, 0
+    # 0, 0, 0, 0    0, 0, 0, 0    0, 0, 0, 0    D, D, D, D
+    # 0, 0, 0, 0    0, 0, 0, 0    0, 0, 0, 0    D, D, D, D
+    # 0, 0, 0, 0    0, 0, 0, 0    0, 0, 0, 0    D, D, D, D
+    # 0, 0, 0, 0    0, 0, 0, 0    0, 0, 0, 0    D, D, D, D
+    #   shard-0       shard-1       shard-2       shard-3
+    # Shards 0, 1, 2, 3 have 3 (A rows), 2 (B rows), 5 (C rows) and 4 (D rows) respectively.
+    gmm_res = shard_map(
+        _gmm,
+        mesh=mesh,
+        in_specs=(P(), P("model", None, None), P(), P("model")),
+        out_specs=(P("model", None)),
+        check_rep=False,
+    )(lhs, rhs, group_sizes, group_offset)
+
+    # The i-th shard is responsible for groups (AKA experts) i*num_experts_per_shard to (i+1)*num_experts_per_shard.
+    # We sum their group sizes to get the total rows in that shard, which is the size the shard sends to its peers.
+    # This is also the number of non-zero rows in the gmm results.
+    # In the working example, send_sizes would be [3, 2, 5, 4]
+    send_sizes = jnp.array([
+        group_sizes[i * num_experts_per_shard:(i + 1) *
+                    num_experts_per_shard].sum() for i in range(ep_size)
+    ])
+    # In the working example, input_offsets would be [0, 3, 5, 10]
+    input_offsets = jnp.concatenate((jnp.array([0]), send_sizes.cumsum()[:-1]))
+    output_offsets = input_offsets
+    recv_sizes = send_sizes
+
+    input_offsets = jax.lax.with_sharding_constraint(
+        input_offsets, NamedSharding(mesh, P("model")))
+    send_sizes = jax.lax.with_sharding_constraint(
+        send_sizes, NamedSharding(mesh, P("model")))
+    output_offsets = jax.lax.with_sharding_constraint(
+        output_offsets, NamedSharding(mesh, P("model")))
+
+    def _ragged_all_to_all(operand, input_offsets, send_sizes, output_offsets,
+                           recv_sizes):
+        output = jnp.zeros_like(operand)
+
+        # input_offsets, send_sizes and output_offsets are sharded and there is only 1 element in each shard; we
+        # take the 0-th element from them just so that jnp.repeat generates arrays with the correct shape.
+        input_offsets_of_shard = jnp.repeat(input_offsets[0], ep_size)
+        send_sizes_of_shard = jnp.repeat(send_sizes[0], ep_size)
+        output_offsets_of_shard = jnp.repeat(output_offsets[0], ep_size)
+
+        # recv_sizes is replicated across shards, because all the shards receive the same data and write to the
+        # output in the same way (same output_offsets and same recv_sizes) and thus generate replicated output.
+        recv_sizes_of_shard = recv_sizes
+
+        # In the working example, for each shard, the values of the offsets and sizes would be:
+        #                           shard-0        shard-1        shard-2        shard-3
+        # input_offsets_of_shard    [0, 0, 0, 0]   [3, 3, 3, 3]   [5, 5, 5, 5]   [10,10,10,10]
+        # send_sizes_of_shard       [3, 3, 3, 3]   [2, 2, 2, 2]   [5, 5, 5, 5]   [4, 4, 4, 4 ]
+        # output_offsets_of_shard   [0, 0, 0, 0]   [3, 3, 3, 3]   [5, 5, 5, 5]   [10,10,10,10]
+        # recv_sizes_of_shard       [3, 2, 5, 4]   [3, 2, 5, 4]   [3, 2, 5, 4]   [3, 2, 5, 4]
+        return jax.lax.ragged_all_to_all(operand,
+                                         output,
+                                         input_offsets_of_shard,
+                                         send_sizes_of_shard,
+                                         output_offsets_of_shard,
+                                         recv_sizes_of_shard,
+                                         axis_name="model")
+
+    # Use ragged_all_to_all to send the result from gmm for each expert to all the shards.
+    # In the working example, the result would be:
+    # A, A, A, A    A, A, A, A    A, A, A, A    A, A, A, A
+    # A, A, A, A    A, A, A, A    A, A, A, A    A, A, A, A
+    # A, A, A, A    A, A, A, A    A, A, A, A    A, A, A, A
+    # B, B, B, B    B, B, B, B    B, B, B, B    B, B, B, B
+    # B, B, B, B    B, B, B, B    B, B, B, B    B, B, B, B
+    # C, C, C, C    C, C, C, C    C, C, C, C    C, C, C, C
+    # C, C, C, C    C, C, C, C    C, C, C, C    C, C, C, C
+    # C, C, C, C    C, C, C, C    C, C, C, C    C, C, C, C
+    # C, C, C, C    C, C, C, C    C, C, C, C    C, C, C, C
+    # C, C, C, C    C, C, C, C    C, C, C, C    C, C, C, C
+    # D, D, D, D    D, D, D, D    D, D, D, D    D, D, D, D
+    # D, D, D, D    D, D, D, D    D, D, D, D    D, D, D, D
+    # D, D, D, D    D, D, D, D    D, D, D, D    D, D, D, D
+    # D, D, D, D    D, D, D, D    D, D, D, D    D, D, D, D
+    #   shard-0       shard-1       shard-2       shard-3
+    return shard_map(
+        _ragged_all_to_all,
+        mesh=mesh,
+        in_specs=(P("model", None), P("model"), P("model"), P("model"), P()),
+        out_specs=(P()),
+        check_rep=False,
+    )(gmm_res, input_offsets, send_sizes, output_offsets, recv_sizes)
+
+
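To make the "working example" numbers in the comments above concrete, the send/receive bookkeeping can be reproduced outside of shard_map with ordinary array ops (illustration only; the values match the 4-shard example with group_sizes [3, 2, 5, 4]):

    import jax.numpy as jnp

    ep_size = 4
    num_experts_per_shard = 1
    group_sizes = jnp.array([3, 2, 5, 4])  # rows routed to experts 0..3

    send_sizes = jnp.array([
        group_sizes[i * num_experts_per_shard:(i + 1) *
                    num_experts_per_shard].sum() for i in range(ep_size)
    ])
    input_offsets = jnp.concatenate((jnp.array([0]), send_sizes.cumsum()[:-1]))

    print(send_sizes)     # [3 2 5 4]  -> rows each shard contributes
    print(input_offsets)  # [0 3 5 10] -> where those rows start in the global layout
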
+def jax_fused_moe_func(
+    hidden_states: jax.Array,
+    w1: jax.Array,
+    w2: jax.Array,
+    gating_output: jax.Array,
+    topk: int,
+    global_num_experts: int,
+    renormalize: bool,
+    reduce_results: bool,
+    mesh: Mesh,
+    use_ep: bool,
+):
+    """
+    Args:
+        hidden_states: [*, hidden_size]
+        w1: [num_experts, intermediate_size * 2, hidden_size]
+        w2: [num_experts, hidden_size, intermediate_size]
+        gating_output: [*, num_experts]
+    """
+    # adapted from https://github.com/vllm-project/vllm/blob/29fa5cac1cd731026f59084d93a822921507573c/vllm/model_executor/layers/fused_moe/moe_pallas.py#L26
+    orig_shape = hidden_states.shape
+    hidden_size = hidden_states.shape[-1]
+    num_tokens = hidden_states.size // hidden_size
+    assert global_num_experts == w1.shape[0]
+    ep_size = mesh.shape["model"]  # only used if use_ep is True.
+    intermediate_size = w2.shape[-1]
+    dtype = hidden_states.dtype
+    assert (num_tokens * topk) % 16 == 0, (
+        "The kernel requires num_tokens * topk to be a multiple of "
+        f"16 but got {num_tokens}*{topk}={num_tokens*topk}")
+
+    hidden_states = hidden_states.reshape(num_tokens, hidden_size)
+    gating_output = gating_output.reshape(num_tokens, global_num_experts)
+
+    topk_weights = jax.nn.softmax(gating_output.astype(jnp.float32), axis=-1)
+    topk_weights, topk_indices = jax.lax.top_k(topk_weights, k=topk)
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)
+    topk_weights = topk_weights.astype(dtype)
+
+    topk_indices_flat = topk_indices.flatten()
+    topk_argsort_indices = jnp.argsort(topk_indices_flat)
+    topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
+    token_indices = jnp.arange(num_tokens, dtype=jnp.int32).repeat(topk)
+    token_indices_sorted = token_indices[topk_argsort_indices]
+    group_sizes = jnp.bincount(topk_indices_flat, length=global_num_experts)
+
+    x = hidden_states[token_indices_sorted]
+
+    if use_ep:
+        x = expert_sharded_gmm(x,
+                               w1,
+                               group_sizes,
+                               transpose_rhs=True,
+                               mesh=mesh,
+                               num_experts=global_num_experts,
+                               ep_size=ep_size)
+        x1, x2 = x[..., :intermediate_size], x[..., intermediate_size:]
+    else:
+        x1, x2 = tensor_sharded_gmm_merged_column_parallel(
+            x,
+            w1,
+            group_sizes,
+            transpose_rhs=True,
+            mesh=mesh,
+            intermediate_size=intermediate_size)
+
+    x = jax.nn.silu(x1) * x2
+
+    if use_ep:
+        x = expert_sharded_gmm(x,
+                               w2,
+                               group_sizes,
+                               transpose_rhs=True,
+                               mesh=mesh,
+                               num_experts=global_num_experts,
+                               ep_size=ep_size)
+    else:
+        x = jax.lax.with_sharding_constraint(
+            x, NamedSharding(mesh, P(None, "model")))
+        x = tensor_sharded_gmm_row_parallel(x,
+                                            w2,
+                                            group_sizes,
+                                            transpose_rhs=True,
+                                            mesh=mesh)
+
+    x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
+    x = x * jnp.expand_dims(topk_weights, axis=-1)
+    x = x.sum(axis=-2)
+    x = x.reshape(orig_shape)
+
+    if reduce_results:
+        x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
+    return x
+
+
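The routing bookkeeping above (top-k expert selection, sorting rows by expert id, counting rows per expert) can be exercised in isolation; this sketch only shows the index math and ignores the multiple-of-16 kernel constraint and the GMM calls:

    import jax
    import jax.numpy as jnp

    num_tokens, num_experts, topk = 4, 4, 2
    gating = jnp.array([[0.1, 2.0, 0.3, 0.0],
                        [3.0, 0.1, 0.2, 0.0],
                        [0.0, 0.2, 1.5, 0.4],
                        [0.5, 1.0, 0.2, 2.0]])

    weights = jax.nn.softmax(gating, axis=-1)
    topk_weights, topk_indices = jax.lax.top_k(weights, k=topk)

    topk_indices_flat = topk_indices.flatten()
    order = jnp.argsort(topk_indices_flat)   # groups the rows by expert id
    revert = jnp.argsort(order)              # undoes the grouping afterwards
    token_indices = jnp.arange(num_tokens, dtype=jnp.int32).repeat(topk)
    group_sizes = jnp.bincount(topk_indices_flat, length=num_experts)

    print(topk_indices)  # [[1 2] [0 2] [2 3] [3 1]]
    print(group_sizes)   # [1 2 3 2] -- rows each expert will receive
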
+def jax_fused_moe_func_padded(hidden_states: jax.Array, w1: jax.Array,
+                              w2: jax.Array, gating_output: jax.Array,
+                              topk: int, global_num_experts: int,
+                              renormalize: bool, reduce_results: bool,
+                              mesh: Mesh, use_ep: bool):
+    # TODO(fanhongmin@google.com): Once the jax runner pads the input, we no longer need this.
+    hidden_size = hidden_states.shape[-1]
+    num_tokens = hidden_states.size // hidden_size
+    if num_tokens * topk < 16:
+        assert 16 % (num_tokens *
+                     topk) == 0, f"Cannot pad to 16: {num_tokens=}, {topk=}"
+        n_repeats = 16 // (num_tokens * topk)
+
+        reps = (n_repeats, ) + (1, ) * (hidden_states.ndim - 1)
+        expanded_hidden_states = jnp.tile(hidden_states, reps)
+
+        reps = (n_repeats, ) + (1, ) * (gating_output.ndim - 1)
+        expanded_gating_output = jnp.tile(gating_output, reps)
+
+        expanded_x = jax_fused_moe_func(expanded_hidden_states, w1, w2,
+                                        expanded_gating_output, topk,
+                                        global_num_experts, renormalize,
+                                        reduce_results, mesh, use_ep)
+        x = expanded_x[:hidden_states.shape[0]]
+        return x
+    else:
+        return jax_fused_moe_func(hidden_states, w1, w2, gating_output, topk,
+                                  global_num_experts, renormalize,
+                                  reduce_results, mesh, use_ep)
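The padded wrapper above only exists to satisfy the kernel's num_tokens * topk multiple-of-16 requirement for very small batches; its arithmetic, spelled out (illustration only):

    num_tokens, topk = 4, 2                  # 8 scheduled rows, fewer than 16
    assert 16 % (num_tokens * topk) == 0     # otherwise the wrapper raises
    n_repeats = 16 // (num_tokens * topk)    # -> 2
    # jax_fused_moe_func_padded tiles hidden_states and gating_output n_repeats
    # times along the token dimension, runs jax_fused_moe_func, and keeps only
    # the first num_tokens rows of the result.
    print(n_repeats)
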
@@ -0,0 +1,186 @@
+from typing import Optional, Union
+
+import jax
+import jax.numpy as jnp
+import torch
+from jax.experimental.shard_map import shard_map
+from jax.sharding import Mesh, NamedSharding
+from jax.sharding import PartitionSpec as P
+from torchax.interop import torch_view
+from torchax.ops.mappings import t2j
+
+from tpu_inference.kernels.quantized_matmul.kernel import \
+    quantized_matmul_kernel
+
+
+def sharded_quantized_matmul(x: jax.Array, w_q: jax.Array, w_s: jax.Array,
+                             mesh: Mesh, weight_sharding: P):
+    out_axis, in_axis = weight_sharding
+    x_sharding = P(None, in_axis)
+    scale_sharding = P(out_axis, )
+    out_sharding = P(None, out_axis)
+
+    x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, x_sharding))
+
+    def wrapper(x, w_q, w_s):
+        output = quantized_matmul_kernel(x, w_q, w_s, x_q_dtype=w_q.dtype)
+        if in_axis:
+            output = jax.lax.psum(output, axis_name=in_axis)
+        return output
+
+    return shard_map(wrapper,
+                     mesh=mesh,
+                     in_specs=(x_sharding, weight_sharding, scale_sharding),
+                     out_specs=(out_sharding),
+                     check_rep=False)(x, w_q, w_s)
+
+
+def reorder_concatenated_tensor_for_sharding(concatenated_tensor: jax.Array,
+                                             split_sizes: list[int],
+                                             n_shards: int, dim: int):
+    """
+    Reorder a replicated concatenated tensor such that when sharded on multiple chips, each shard is a concatenation of the shards of the individual tensors.
+    For example, let the concatenated_tensor be:
+        AAAAAAAAAAAABBBBBBBBCCCC
+          12 As      8 Bs   4 Cs
+    and let the split_sizes = [12, 8, 4] and n_shards = 4.
+    The output is:
+        AAABBCAAABBCAAABBCAAABBC
+    In other words, it reorders the input tensor into 4 segments, with each segment corresponding to a shard and being AAABBC.
+    Args:
+        concatenated_tensor: the tensor, concatenated on the dimension specified by `dim`.
+        split_sizes: each individual tensor's size on the dimension specified by `dim`.
+        n_shards: num of shards.
+        dim: the dimension on which the concatenated_tensor is concatenated.
+    """
+    # Split the concatenated tensor into individual tensors.
+    split_tensors = []
+    start_offset = 0
+    old_shape = concatenated_tensor.shape
+    # New shape ensures each split_tensor[i] maps to a tensor in the i-th shard
+    new_shape = old_shape[:dim] + (n_shards, -1) + old_shape[dim + 1:]
+    for split_size in split_sizes:
+        split_tensor = jax.lax.slice_in_dim(concatenated_tensor,
+                                            start_offset,
+                                            start_offset + split_size,
+                                            axis=dim)
+        split_tensors.append(split_tensor.reshape(new_shape))
+        start_offset += split_size
+    # While maintaining the 0th dim as a shard dim, we concatenate along the 1st dim
+    # to create a concatenated tensor whose 0th dim maps to the shard dim.
+    reordered_tensor = jnp.concatenate(split_tensors, axis=dim + 1)
+    return reordered_tensor.reshape(old_shape)
+
+
+def slice_sharded_tensor_for_concatenation(sharded_tensor: jax.Array,
+                                           split_sizes: list[int],
+                                           n_shards: int):
+    """
+    Slice the input tensor which is sharded on multiple chips (on the last dim) into individual tensors with the same sharding.
+    For example, let the sharded_tensor be:
+        AAABBC | AAABBC | AAABBC | AAABBC
+        Shard0   Shard1   Shard2   Shard3
+    and let the split_sizes = [12, 8, 4] and n_shards = 4.
+    The output is a list of 3 tensors:
+        AAA | AAA | AAA | AAA
+        BB  | BB  | BB  | BB
+        C   | C   | C   | C
+        Shard0 Shard1 Shard2 Shard3
+    In other words, each individual tensor is a slice of the input tensor with the same sharding.
+    Args:
+        sharded_tensor: the input tensor, sharded on the last dim.
+        split_sizes: each individual tensor's size on the last dim.
+        n_shards: num of shards.
+    """
+    new_shape = sharded_tensor.shape[:-1] + (n_shards, -1)
+    # New shape ensures each sharded_tensor[:, i] maps to a tensor in the i-th shard
+    sharded_tensor = sharded_tensor.reshape(new_shape)
+
+    split_tensors = []
+    start_offset = 0
+    for split_size in split_sizes:
+        assert split_size % n_shards == 0
+        sz = split_size // n_shards  # size of this split tensor per shard
+        end_offset = start_offset + sz
+        # Because we are slicing over the last dim, the sharding dim remains intact.
+        # Therefore, splitting happens locally.
+        split_tensor = sharded_tensor[..., start_offset:end_offset]
+        split_tensors.append(split_tensor.reshape(new_shape[:-2] + (-1, )))
+        start_offset = end_offset
+
+    return split_tensors
+
+
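This second hunk appears to be tpu_inference/layers/vllm/linear_common.py (the 186-line addition in the listing). The reorder/slice pair above is easiest to see on the docstrings' A/B/C example; a hedged usage sketch, assuming the wheel is installed and using integer labels 0/1/2 in place of A/B/C:

    import jax.numpy as jnp

    from tpu_inference.layers.vllm.linear_common import (
        reorder_concatenated_tensor_for_sharding,
        slice_sharded_tensor_for_concatenation)

    split_sizes, n_shards = [12, 8, 4], 4
    # 12 "A"s (0), 8 "B"s (1), 4 "C"s (2), concatenated on dim 0.
    concatenated = jnp.repeat(jnp.array([0, 1, 2]), jnp.array(split_sizes))

    reordered = reorder_concatenated_tensor_for_sharding(
        concatenated, split_sizes, n_shards, dim=0)
    print(reordered.reshape(n_shards, -1))
    # each row (one shard's slice) reads [0 0 0 1 1 2], i.e. "AAABBC"

    pieces = slice_sharded_tensor_for_concatenation(reordered, split_sizes,
                                                    n_shards)
    print([p.shape for p in pieces])  # [(12,), (8,), (4,)] -- the As, Bs, Cs again
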
+def torch_to_jax_param(
+    tensor: torch.Tensor,
+    sharding: NamedSharding,
+    output_sizes: Optional[list[int]],
+    n_shards: int,
+    fused: bool,
+    dim: int = 0,
+    jax_dtype: Optional[jnp.dtype] = None,
+) -> Union[torch.nn.Parameter, torch.nn.ParameterList]:
+    if output_sizes is None:
+        output_sizes = [tensor.shape[0]]
+
+    tensor = t2j(tensor, use_dlpack=False)
+    if jax_dtype:
+        tensor = tensor.astype(jax_dtype)
+
+    if fused:
+        tensor = reorder_concatenated_tensor_for_sharding(
+            tensor, output_sizes, n_shards, dim)
+        tensor = jax.device_put(tensor, sharding)
+        param = torch.nn.Parameter(torch_view(tensor), requires_grad=False)
+    else:
+        tensors = []
+        start_offset = 0
+        for size in output_sizes:
+            end_offset = start_offset + size
+
+            tensor_split = jax.lax.slice_in_dim(tensor,
+                                                start_offset,
+                                                end_offset,
+                                                axis=dim)
+            tensor_split = jax.device_put(tensor_split, sharding)
+            tensor_split = torch.nn.Parameter(torch_view(tensor_split),
+                                              requires_grad=False)
+            tensors.append(tensor_split)
+
+            start_offset = end_offset
+        param = torch.nn.ParameterList(tensors)
+    return param
+
+
+MODEL_MATMUL_FUSION_TRUTH_TABLE = {
+    ("Qwen/Qwen2.5-7B-Instruct", 1024, 1, "QKVParallelLinear"):
+    True,
+    ("Qwen/Qwen2.5-7B-Instruct", 1024, 1, "MergedColumnParallelLinear"):
+    False,
+    ("Qwen/Qwen2.5-7B-Instruct", 2048, 1, "QKVParallelLinear"):
+    False,
+    ("Qwen/Qwen2.5-7B-Instruct", 2048, 1, "MergedColumnParallelLinear"):
+    False,
+    ("meta-llama/Llama-3.1-8B-Instruct", 1024, 1, "QKVParallelLinear"):
+    False,
+    ("meta-llama/Llama-3.1-8B-Instruct", 1024, 1, "MergedColumnParallelLinear"):
+    False,
+    ("meta-llama/Llama-3.1-8B-Instruct", 2048, 1, "QKVParallelLinear"):
+    False,
+    ("meta-llama/Llama-3.1-8B-Instruct", 2048, 1, "MergedColumnParallelLinear"):
+    False,
+    ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 1024, 1, "QKVParallelLinear"):
+    False,
+    ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 1024, 1, "MergedColumnParallelLinear"):
+    False,
+    ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 2048, 1, "QKVParallelLinear"):
+    False,
+    ("RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8", 2048, 1, "MergedColumnParallelLinear"):
+    False,
+}
+
+
+def get_model_matmul_fusion_assignment(model_name: str, batch_size: int,
+                                       tp_size: int, layer_name: str):
+    key = (model_name, batch_size, tp_size, layer_name)
+    return MODEL_MATMUL_FUSION_TRUTH_TABLE.get(key, True)
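The truth table above gates matmul fusion per (model, batch size, TP size, layer) and falls back to True for any combination that is not listed. A small lookup sketch (assumes the wheel is installed):

    from tpu_inference.layers.vllm.linear_common import \
        get_model_matmul_fusion_assignment

    # Key present in the table:
    print(get_model_matmul_fusion_assignment("Qwen/Qwen2.5-7B-Instruct", 1024, 1,
                                             "QKVParallelLinear"))          # True
    # Key present in the table and explicitly disabled:
    print(get_model_matmul_fusion_assignment("Qwen/Qwen2.5-7B-Instruct", 2048, 1,
                                             "QKVParallelLinear"))          # False
    # Unlisted combination -> defaults to True:
    print(get_model_matmul_fusion_assignment("meta-llama/Llama-3.1-8B-Instruct",
                                             4096, 8, "QKVParallelLinear"))  # True
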
@@ -0,0 +1,34 @@
+import copy
+
+from jax.sharding import Mesh
+from vllm.config import VllmConfig
+from vllm.model_executor.layers.quantization.base_config import \
+    QuantizationConfig
+
+from tpu_inference.layers.vllm.quantization.awq import VllmAWQConfig
+from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
+    VllmCompressedTensorsConfig  # noqa: E501
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedConfig
+
+
+def get_tpu_quantization_config(vllm_config: VllmConfig,
+                                mesh: Mesh) -> QuantizationConfig:
+    model_config = copy.deepcopy(vllm_config.model_config)
+    # TODO(kyuyeunk): Add support for "tpu_int8".
+    method_to_config: dict[str | None, type[QuantizationConfig]] = {
+        None: VllmUnquantizedConfig,
+        "compressed-tensors": VllmCompressedTensorsConfig,
+        "awq": VllmAWQConfig,
+    }
+
+    if model_config.quantization not in method_to_config:
+        raise NotImplementedError
+    quant_config = method_to_config[model_config.quantization]
+    assert issubclass(quant_config, JaxCommonConfig)
+    quant_config.set_configs(vllm_config, mesh)
+
+    model_config.quantization = quant_config.get_name()
+    return VllmConfig.get_quantization_config(model_config,
+                                              vllm_config.load_config)
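Finally, this last hunk appears to be tpu_inference/layers/vllm/quantization/__init__.py (the 34-line addition in the listing). A hedged sketch of how a caller might use it; the helper name and the mesh construction (a 1-D "model" axis over all local devices) are assumptions for illustration, not something the diff prescribes:

    import numpy as np

    import jax
    from jax.sharding import Mesh
    from vllm.config import VllmConfig
    from vllm.model_executor.layers.quantization.base_config import \
        QuantizationConfig

    from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config


    def build_tpu_quant_config(vllm_config: VllmConfig) -> QuantizationConfig:
        # Hypothetical helper for this sketch only.
        mesh = Mesh(np.array(jax.devices()), ("model", ))
        # Dispatch happens on vllm_config.model_config.quantization:
        #   None                 -> VllmUnquantizedConfig
        #   "compressed-tensors" -> VllmCompressedTensorsConfig
        #   "awq"                -> VllmAWQConfig
        # Anything else raises NotImplementedError.
        return get_tpu_quantization_config(vllm_config, mesh)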
|