tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511130813__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (67)
  1. tests/kernels/fused_moe_v1_test.py +34 -303
  2. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
  3. tests/lora/test_layers.py +6 -0
  4. tests/lora/utils.py +8 -0
  5. tests/test_utils.py +16 -24
  6. tpu_inference/__init__.py +3 -22
  7. tpu_inference/core/core_tpu.py +9 -17
  8. tpu_inference/core/disagg_utils.py +8 -6
  9. tpu_inference/distributed/tpu_connector.py +4 -3
  10. tpu_inference/distributed/utils.py +2 -3
  11. tpu_inference/envs.py +8 -61
  12. tpu_inference/executors/ray_distributed_executor.py +11 -31
  13. tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
  15. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +143 -287
  16. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -7
  17. tpu_inference/layers/jax/attention/attention.py +1 -1
  18. tpu_inference/layers/{common → jax}/attention_interface.py +2 -8
  19. tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
  20. tpu_inference/layers/jax/sample/sampling.py +2 -2
  21. tpu_inference/layers/{common → jax}/sharding.py +5 -5
  22. tpu_inference/layers/vllm/attention.py +1 -1
  23. tpu_inference/layers/vllm/fused_moe.py +208 -170
  24. tpu_inference/layers/vllm/quantization/__init__.py +3 -7
  25. tpu_inference/layers/vllm/quantization/awq.py +3 -4
  26. tpu_inference/layers/vllm/quantization/common.py +1 -6
  27. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +2 -4
  28. tpu_inference/layers/vllm/quantization/unquantized.py +67 -62
  29. tpu_inference/layers/vllm/sharding.py +2 -2
  30. tpu_inference/lora/torch_punica_tpu.py +2 -1
  31. tpu_inference/mock/__init__.py +0 -0
  32. tpu_inference/mock/vllm_config_utils.py +28 -0
  33. tpu_inference/mock/vllm_envs.py +1219 -0
  34. tpu_inference/mock/vllm_logger.py +212 -0
  35. tpu_inference/mock/vllm_logging_utils.py +15 -0
  36. tpu_inference/models/common/model_loader.py +12 -46
  37. tpu_inference/models/jax/llama3.py +3 -4
  38. tpu_inference/models/jax/llama_eagle3.py +5 -8
  39. tpu_inference/models/jax/phi3.py +376 -0
  40. tpu_inference/models/jax/qwen2.py +2 -3
  41. tpu_inference/models/jax/qwen2_5_vl.py +50 -165
  42. tpu_inference/models/jax/qwen3.py +2 -3
  43. tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
  44. tpu_inference/models/jax/utils/weight_utils.py +143 -198
  45. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -32
  46. tpu_inference/platforms/tpu_platform.py +34 -47
  47. tpu_inference/runner/compilation_manager.py +60 -145
  48. tpu_inference/runner/kv_cache.py +2 -2
  49. tpu_inference/runner/kv_cache_manager.py +18 -17
  50. tpu_inference/runner/persistent_batch_manager.py +2 -40
  51. tpu_inference/runner/structured_decoding_manager.py +3 -2
  52. tpu_inference/runner/tpu_runner.py +135 -283
  53. tpu_inference/runner/utils.py +2 -2
  54. tpu_inference/spec_decode/jax/eagle3.py +21 -71
  55. tpu_inference/tpu_info.py +3 -4
  56. tpu_inference/utils.py +15 -38
  57. tpu_inference/worker/tpu_worker.py +26 -163
  58. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/METADATA +3 -4
  59. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/RECORD +63 -61
  60. tests/test_envs.py +0 -203
  61. tpu_inference/layers/common/quant_methods.py +0 -8
  62. tpu_inference/layers/vllm/quantization/mxfp4.py +0 -331
  63. tpu_inference/models/jax/llama_guard_4.py +0 -361
  64. /tpu_inference/layers/{common → jax}/binary_search.py +0 -0
  65. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/WHEEL +0 -0
  66. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/licenses/LICENSE +0 -0
  67. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/top_level.txt +0 -0
@@ -17,7 +17,7 @@ import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
 from tpu_inference.kernels.flash_attention.kernel import flash_attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.jax.sharding import ShardingAxisName
 from tpu_inference.utils import get_megacore

 MAX_ALLOWED_PAGE_INDICES_N = (
@@ -308,13 +308,7 @@ def sharded_ragged_paged_attention(
     args = (q, k, v, kv_cache, kv_lens, page_indices, cu_q_lens, distribution)

     use_hd64 = q.shape[-1] == 64
-
-    func = ragged_paged_attention
-    if use_hd64:
-        func = functools.partial(ragged_paged_attention_hd64,
-                                 strict_sliding_window=False)
-    else:
-        func = ragged_paged_attention
+    func = ragged_paged_attention_hd64 if use_hd64 else ragged_paged_attention

     if attention_sink is not None:
         if not use_hd64:
@@ -12,7 +12,7 @@ import jax
 import jax.numpy as jnp
 import numpy as np

-from tpu_inference.layers.common.binary_search import topk_mask, topp_mask
+from tpu_inference.layers.jax.binary_search import topk_mask, topp_mask
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata

@@ -6,10 +6,10 @@ from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from vllm.v1.outputs import LogprobsTensors

-from tpu_inference.layers.common.binary_search import topk_mask, topp_mask
-from tpu_inference.layers.common.sharding import ShardingAxisName
+from tpu_inference.layers.jax.binary_search import topk_mask, topp_mask
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
+from tpu_inference.layers.jax.sharding import ShardingAxisName

 _SAMPLING_EPS = 1e-5

@@ -1,5 +1,6 @@
 import json
 import math
+import os
 from dataclasses import asdict, dataclass
 from typing import TYPE_CHECKING, List, Optional

@@ -7,7 +8,7 @@ import jax.numpy as jnp
 import numpy as np
 from jax.sharding import Mesh

-from tpu_inference import envs, utils
+from tpu_inference import utils

 if TYPE_CHECKING:
     from vllm.v1.configs.vllm_config import VllmConfig
@@ -47,7 +48,7 @@ class ShardingAxisName2D:


 try:
-    _use_base_sharding = envs.NEW_MODEL_DESIGN
+    _use_base_sharding = os.getenv("NEW_MODEL_DESIGN", False)
     if _use_base_sharding:
         ShardingAxisName = ShardingAxisNameBase
     else:
@@ -165,10 +166,9 @@ class ShardingConfigManager:
                     f"LoRA is not supported with data parallelism "
                     f"(DP size: {total_dp_size}). Please disable LoRA or "
                     f"set data parallelism to 1.")
-        if sharding_strategy.attention_data_parallelism > 1:
-            if not envs.NEW_MODEL_DESIGN:
+            if not os.environ.get("NEW_MODEL_DESIGN", False):
                 raise ValueError(
-                    "Must run Attention DP with NEW_MODEL_DESIGN enabled. Please set the "
+                    "Must run DP with NEW_MODEL_DESIGN enabled. Please set the "
                     "NEW_MODEL_DESIGN=True.")

     @property
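
The two sharding.py hunks above replace the typed envs.NEW_MODEL_DESIGN flag with raw environment lookups. A minimal, hypothetical sketch of how such a lookup can be exercised; note that os.environ.get returns the raw string whenever the variable is set, so the value is parsed explicitly here (this helper is illustrative and not part of the package):

import os

def new_model_design_enabled() -> bool:
    # os.environ.get returns a str when the variable is set; parse it
    # explicitly instead of relying on the truthiness of the raw string.
    raw = os.environ.get("NEW_MODEL_DESIGN", "")
    return raw.strip().lower() in ("1", "true", "yes", "on")

os.environ["NEW_MODEL_DESIGN"] = "True"
assert new_model_design_enabled()
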
@@ -13,8 +13,8 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)

 from tpu_inference import utils
-from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.jax.attention_interface import attention
 from tpu_inference.logger import init_logger
 from tpu_inference.models.vllm.vllm_model_wrapper_context import \
     get_vllm_model_wrapper_context
@@ -2,16 +2,17 @@ import functools

 import jax
 from jax import numpy as jnp
-from jax import shard_map
 from jax.experimental.pallas.ops.tpu.megablox.gmm import gmm
-from jax.sharding import Mesh
-from jax.sharding import PartitionSpec as P
+from jax.experimental.shard_map import shard_map
+from jax.sharding import Mesh, NamedSharding, PartitionSpec

 from tpu_inference.layers.vllm.linear_common import \
     slice_sharded_tensor_for_concatenation

+P = PartitionSpec

-def activation_fn(activation: str, x1: jax.Array, x2: jax.Array) -> jax.Array:
+
+def activation_fn(activation: str, x1, x2):
     match activation:
         case "silu":
             return jax.nn.silu(x1) * x2
@@ -22,10 +23,7 @@ def activation_fn(activation: str, x1: jax.Array, x2: jax.Array) -> jax.Array:
                 f"FusedMoE does not support {activation} activation")


-def _swigluoai(x1: jax.Array,
-               x2: jax.Array,
-               alpha=1.702,
-               limit=7.0) -> jax.Array:
+def _swigluoai(x1, x2, alpha=1.702, limit=7.0):
     x1 = jnp.clip(x1, a_max=limit)
     x2 = jnp.clip(x2, a_min=-limit, a_max=limit)

@@ -105,53 +103,40 @@ def tensor_sharded_gmm_merged_column_parallel(
     rhs: jax.Array,
     rhs_bias: jax.Array | None,
     group_sizes: jax.Array,
+    transpose_rhs: bool,
     mesh: Mesh,
-) -> tuple[jax.Array, jax.Array]:
-
-    def _gmm(lhs, rhs, group_sizes):
-        m, g, n, k = lhs.shape[0], *rhs.shape
-        tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
-        return gmm(
-            lhs,
-            rhs,
-            group_sizes,
-            preferred_element_type=lhs.dtype,
-            tiling=(tm, tk, tn),
-            transpose_rhs=True,
-            group_offset=jnp.array(0),
-        )
+    intermediate_size: int,
+) -> jax.Array:
+    # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
+    m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
+    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+
+    _gmm = functools.partial(
+        gmm,
+        preferred_element_type=lhs.dtype,
+        tiling=(tm, tk, tn),
+        transpose_rhs=transpose_rhs,
+        group_offset=jnp.array(0),
+    )

     gmm_result = shard_map(
         _gmm,
         mesh=mesh,
-        in_specs=(P("data", None), P(None, "model", None), P("data")),
-        out_specs=(P("data", "model")),
-        check_vma=False,
+        in_specs=(P(), P(None, "model", None), P()),
+        out_specs=(P(None, "model")),
+        check_rep=False,
     )(lhs, rhs, group_sizes)

     if rhs_bias is not None:
+        rhs_bis = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
+        gmm_result = (gmm_result + rhs_bis).astype(gmm_result.dtype)

-        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
-            rhs_bias = jnp.repeat(
-                rhs_bias_local,
-                group_sizes_global,
-                0,
-                total_repeat_length=gmm_result_local.shape[0])
-            return gmm_result_local + rhs_bias
-
-        gmm_result = shard_map(
-            _add_bias,
-            mesh=mesh,
-            in_specs=(P("data", "model"), P(None, "model"), P("data")),
-            out_specs=(P("data", "model")),
-        )(gmm_result, rhs_bias, group_sizes)
-        gmm_result = gmm_result.astype(lhs.dtype)
-
-    tp_size = mesh.shape["model"]
-    intermediate_size = gmm_result.shape[-1] // 2
+    n_shards = mesh.shape["model"]
     output_sizes = [intermediate_size, intermediate_size]
+
     return slice_sharded_tensor_for_concatenation(gmm_result, output_sizes,
-                                                  tp_size)
+                                                  n_shards)


 def tensor_sharded_gmm_row_parallel(
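
In the rewritten column-parallel path above, the per-expert bias is broadcast to one row per routed token with jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m). A self-contained sketch of that expansion with made-up shapes, showing why the static total_repeat_length is needed once the call sits under jax.jit:

import jax
import jax.numpy as jnp

num_experts, hidden = 3, 4
m = 6  # total rows (token copies) routed across all experts
bias = jnp.arange(num_experts * hidden, dtype=jnp.float32).reshape(num_experts, hidden)
group_sizes = jnp.array([1, 3, 2])  # rows assigned to each expert, sums to m

@jax.jit
def expand_bias(bias, group_sizes):
    # Repeat row i of `bias` group_sizes[i] times; total_repeat_length gives
    # the static output length that jit needs even though group_sizes is traced.
    return jnp.repeat(bias, group_sizes, axis=0, total_repeat_length=m)

per_token_bias = expand_bias(bias, group_sizes)
assert per_token_bias.shape == (m, hidden)
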
@@ -159,75 +144,74 @@ def tensor_sharded_gmm_row_parallel(
     rhs: jax.Array,
     rhs_bias: jax.Array | None,
     group_sizes: jax.Array,
+    transpose_rhs: bool,
     mesh: Mesh,
 ) -> jax.Array:
+    # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
+    m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
+    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+
+    _gmm = functools.partial(
+        gmm,
+        preferred_element_type=lhs.dtype,
+        tiling=(tm, tk, tn),
+        transpose_rhs=transpose_rhs,
+        group_offset=jnp.array(0),
+    )

     def _gmm_all_reduce(lhs, rhs, group_sizes):
-        m, g, n, k = lhs.shape[0], *rhs.shape
-        tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
-        out = gmm(
-            lhs,
-            rhs,
-            group_sizes,
-            preferred_element_type=lhs.dtype,
-            tiling=(tm, tk, tn),
-            transpose_rhs=True,
-            group_offset=jnp.array(0),
-        )
-        return jax.lax.psum(out, axis_name="model")
+        r = _gmm(lhs, rhs, group_sizes)
+        return jax.lax.psum(r, axis_name="model")

     gmm_result = shard_map(
         _gmm_all_reduce,
         mesh=mesh,
-        in_specs=(P("data", "model"), P(None, None, "model"), P("data")),
-        out_specs=(P("data")),
-        check_vma=False,
+        in_specs=(P(None, "model"), P(None, None, "model"), P()),
+        out_specs=(P()),
+        check_rep=False,
     )(lhs, rhs, group_sizes)

     if rhs_bias is not None:
+        rhs_bias = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
+        gmm_result = (gmm_result + rhs_bias).astype(gmm_result.dtype)

-        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
-            rhs_bias = jnp.repeat(
-                rhs_bias_local,
-                group_sizes_global,
-                0,
-                total_repeat_length=gmm_result_local.shape[0])
-            return gmm_result_local + rhs_bias
-
-        gmm_result = shard_map(
-            _add_bias,
-            mesh=mesh,
-            in_specs=(P("data"), P(), P("data")),
-            out_specs=(P("data")),
-        )(gmm_result, rhs_bias, group_sizes)
-
-    return gmm_result.astype(lhs.dtype)
+    return gmm_result


 def expert_sharded_gmm(
     lhs: jax.Array,
     rhs: jax.Array,
     group_sizes: jax.Array,
+    transpose_rhs: bool,
     mesh: Mesh,
+    num_experts: int,
+    ep_size: int,
 ) -> jax.Array:
-    ep_size = mesh.shape["model"]
+    # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
+    m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
+    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)

-    num_experts = rhs.shape[0]
     num_experts_per_shard = num_experts // ep_size
     group_offset = jnp.arange(0, num_experts, num_experts_per_shard)
+    group_offset = jax.lax.with_sharding_constraint(
+        group_offset, NamedSharding(mesh, P("model")))

     def _gmm(lhs, rhs, group_sizes, group_offset):
-        m, g, n, k = lhs.shape[0], *rhs.shape
-        tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
-
+        # Group offset for this shard. `group_offset` is sharded, and in this
+        # sharded function, it has only 1 element and `group_offset.shape` is
+        # (1,) but gmm kernel requires the group_offset to be a ()-shaped array,
+        # so we group_offset[0].
+        group_offset_of_shard = group_offset[0]
         gmm_res = gmm(
             lhs=lhs,
             rhs=rhs,
             group_sizes=group_sizes,
             preferred_element_type=lhs.dtype,
             tiling=(tm, tk, tn),
-            transpose_rhs=True,
-            group_offset=group_offset[0],
+            transpose_rhs=transpose_rhs,
+            group_offset=group_offset_of_shard,
         )
         return gmm_res

@@ -254,24 +238,30 @@ def expert_sharded_gmm(
         mesh=mesh,
         in_specs=(P(), P("model", None, None), P(), P("model")),
         out_specs=(P("model", None)),
-        check_vma=False,
+        check_rep=False,
     )(lhs, rhs, group_sizes, group_offset)

     # For i-th shard, it is responsible groups (AKA experts) from
     # i*num_experts_per_shard to (i+1)*num_experts_per_shard We sum them up to
     # get total rows in that shard, and that is the size for shard to send to
     # its peers. This is also the number of non-zero rows from the gmm results.
-    # In the working example, send_sizes would be [3, 2, 5, 4].
-
-    # group_sizes has shape of [num_tokens_per_shard * num_experts_per_shard].
-    # So reshaping to [num_tokens_per_shard, num_experts_per_shard] and applying
-    # sum(axis=1) will get desired send_sizes shaped [num_tokens_per_shard].
-    send_sizes = group_sizes.reshape(-1, num_experts_per_shard).sum(axis=1)
+    # In the working example, send_sizes would be [3, 2, 5, 4]
+    send_sizes = jnp.array([
+        group_sizes[i * num_experts_per_shard:(i + 1) *
+                    num_experts_per_shard].sum() for i in range(ep_size)
+    ])
     # In the working example, input_offsets would be [0, 3, 5, 10]
     input_offsets = jnp.concatenate((jnp.array([0]), send_sizes.cumsum()[:-1]))
     output_offsets = input_offsets
     recv_sizes = send_sizes

+    input_offsets = jax.lax.with_sharding_constraint(
+        input_offsets, NamedSharding(mesh, P("model")))
+    send_sizes = jax.lax.with_sharding_constraint(
+        send_sizes, NamedSharding(mesh, P("model")))
+    output_offsets = jax.lax.with_sharding_constraint(
+        output_offsets, NamedSharding(mesh, P("model")))
+
     def _ragged_all_to_all(operand, input_offsets, send_sizes, output_offsets,
                            recv_sizes):
         output = jnp.zeros_like(operand)
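
The comments above walk through a working example in which the per-shard row counts come out to [3, 2, 5, 4] and the offsets to [0, 3, 5, 10]. A short sketch reproducing that arithmetic, assuming 4 shards with 2 experts each (the concrete group_sizes values are made up to match the example):

import jax.numpy as jnp

ep_size = 4
num_experts_per_shard = 2
# Per-expert row counts; experts 0-1 live on shard 0, experts 2-3 on shard 1, etc.
group_sizes = jnp.array([1, 2, 2, 0, 4, 1, 3, 1])

send_sizes = jnp.array([
    group_sizes[i * num_experts_per_shard:(i + 1) *
                num_experts_per_shard].sum() for i in range(ep_size)
])
input_offsets = jnp.concatenate((jnp.array([0]), send_sizes.cumsum()[:-1]))

print(send_sizes)     # [3 2 5 4]
print(input_offsets)  # [0 3 5 10]
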
@@ -326,20 +316,10 @@ def expert_sharded_gmm(
         mesh=mesh,
         in_specs=(P("model", None), P("model"), P("model"), P("model"), P()),
         out_specs=(P()),
-        check_vma=False,
+        check_rep=False,
     )(gmm_res, input_offsets, send_sizes, output_offsets, recv_sizes)


-@functools.partial(
-    jax.jit,
-    static_argnames=(
-        "topk",
-        "renormalize",
-        "mesh",
-        "use_ep",
-        "activation",
-    ),
-)
 def fused_moe_func(
     hidden_states: jax.Array,
     w1: jax.Array,
348
328
  w2_bias: jax.Array | None,
349
329
  gating_output: jax.Array,
350
330
  topk: int,
331
+ global_num_experts: int,
351
332
  renormalize: bool,
333
+ reduce_results: bool,
352
334
  mesh: Mesh,
353
335
  use_ep: bool,
354
336
  activation: str,
355
- ) -> jax.Array:
337
+ ):
356
338
  """
357
- Route tokens in hidden_states into each experts based on routing
358
- information in gating_out and performs moe with w1 and w2 weights.
359
-
360
339
  Args:
361
- hidden_states: [num_tokens, hidden_size]
362
- w1: first moe weights [num_experts, intermediate_size * 2, hidden_size]
363
- w2: second moe weights [num_experts, hidden_size, intermediate_size]
364
- w1_bias: optional bias of w1 [num_experts, intermediate_size * 2]
365
- w2_bias: optional bias of w2 [num_experts, hidden_size]
366
- gating_output: routing information of tokens [num_tokens, num_experts]
367
- topk: number of experts to choose per token.
368
- renormalize: normalize gating_output.
369
- mesh: mesh to perform moe.
370
- use_ep: use expert parallelism.
371
- activation: activation function to perform on the output of w1.
372
-
373
- Returns:
374
- Output of moe operation [num_tokens, hidden_size]
340
+ hidden_states: [*, hidden_size]
341
+ w1: [num_experts, intermediate_size * 2, hidden_size]
342
+ w2: [num_experts, hidden_size, intermediate_size]
343
+ gating_output: [*, num_experts]
375
344
  """
345
+ # adapted from https://github.com/vllm-project/vllm/blob/29fa5cac1cd731026f59084d93a822921507573c/vllm/model_executor/layers/fused_moe/moe_pallas.py#L26
376
346
  if use_ep and (w1_bias is not None or w2_bias is not None):
377
347
  raise NotImplementedError(
378
348
  "Bias is not supported when using expert parallelism.")
379
-
380
- num_tokens = hidden_states.shape[0]
381
- global_num_experts, hidden_size, intermediate_size = w2.shape
349
+ orig_shape = hidden_states.shape
350
+ hidden_size = hidden_states.shape[-1]
351
+ num_tokens = hidden_states.size // hidden_size
352
+ assert global_num_experts == w1.shape[0]
353
+ ep_size = mesh.shape["model"] # only used if use_ep is True.
354
+ intermediate_size = w2.shape[-1]
382
355
  dtype = hidden_states.dtype
383
-
384
356
  assert (num_tokens * topk) % 16 == 0, (
385
357
  "The kernel requires num_tokens * topk to be a multiple of "
386
358
  f"16 but got {num_tokens}*{topk}={num_tokens*topk}")
387
- assert hidden_states.shape == (num_tokens, hidden_size)
388
- assert gating_output.shape == (num_tokens, global_num_experts)
389
- assert w1.shape == (global_num_experts, intermediate_size * 2, hidden_size)
359
+
360
+ hidden_states = hidden_states.reshape(num_tokens, hidden_size)
361
+ gating_output = gating_output.reshape(num_tokens, global_num_experts)
390
362
 
391
363
  topk_weights = jax.nn.softmax(gating_output.astype(jnp.float32), axis=-1)
392
364
  topk_weights, topk_indices = jax.lax.top_k(topk_weights, k=topk)
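
The new preamble above flattens hidden_states to [num_tokens, hidden_size] and keeps the kernel requirement that num_tokens * topk be a multiple of 16; the fused_moe_func_padded wrapper in the next hunk tiles small batches up to that boundary. A quick sketch of the padding arithmetic with illustrative shapes:

import jax.numpy as jnp

# Illustrative shapes: 2 tokens with topk=4 give num_tokens * topk == 8 < 16,
# so the batch is tiled until it reaches the 16-row boundary.
num_tokens, topk, hidden_size = 2, 4, 8
hidden_states = jnp.ones((num_tokens, hidden_size))

assert 16 % (num_tokens * topk) == 0
n_repeats = 16 // (num_tokens * topk)          # 2
reps = (n_repeats, ) + (1, ) * (hidden_states.ndim - 1)
expanded = jnp.tile(hidden_states, reps)       # shape (4, 8)

assert (expanded.shape[0] * topk) % 16 == 0
# After the MoE call, only the first num_tokens rows are kept again.
restored = expanded[:num_tokens]
assert restored.shape == hidden_states.shape
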
@@ -394,76 +366,142 @@ def fused_moe_func(
     topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)
     topk_weights = topk_weights.astype(dtype)

-    def _process_tokens_locally(hidden_states_local, topk_indices_local):
-        num_tokens_local = hidden_states_local.shape[0]
-        topk_indices_flat = topk_indices_local.flatten()
-        topk_argsort_indices = jnp.argsort(topk_indices_flat)
-        topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
-        token_indices = jnp.arange(num_tokens_local,
-                                   dtype=jnp.int32).repeat(topk)
-        token_indices_sorted = token_indices[topk_argsort_indices]
-        group_sizes_local = jnp.bincount(topk_indices_flat,
-                                         length=global_num_experts)
-
-        x = hidden_states_local[token_indices_sorted]
-        return x, group_sizes_local, topk_argsort_revert_indices
-
-    x, group_sizes, topk_argsort_revert_indices = shard_map(
-        _process_tokens_locally,
-        mesh=mesh,
-        in_specs=(P("data", None), P("data", None)),
-        out_specs=(P("data", None), P("data"), P("data")),
-    )(hidden_states, topk_indices)
+    topk_indices_flat = topk_indices.flatten()
+    topk_argsort_indices = jnp.argsort(topk_indices_flat)
+    topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
+    token_indices = jnp.arange(num_tokens, dtype=jnp.int32).repeat(topk)
+    token_indices_sorted = token_indices[topk_argsort_indices]
+    group_sizes = jnp.bincount(topk_indices_flat, length=global_num_experts)
+
+    x = hidden_states[token_indices_sorted]

     if use_ep:
         x = expert_sharded_gmm(
             x,
             w1,
             group_sizes,
+            transpose_rhs=True,
             mesh=mesh,
+            num_experts=global_num_experts,
+            ep_size=ep_size,
         )
-        x1, x2 = jnp.split(x, 2, -1)
-
-        x = activation_fn(activation, x1, x2)
-
-        x = expert_sharded_gmm(
-            x,
-            w2,
-            group_sizes,
-            mesh=mesh,
-        )
+        x1, x2 = x[..., :intermediate_size], x[..., intermediate_size:]
     else:
         x1, x2 = tensor_sharded_gmm_merged_column_parallel(
             x,
             w1,
             w1_bias,
             group_sizes,
+            transpose_rhs=True,
             mesh=mesh,
+            intermediate_size=intermediate_size,
         )

-        x = activation_fn(activation, x1, x2)
+    x = activation_fn(activation, x1, x2)

+    if use_ep:
+        x = expert_sharded_gmm(
+            x,
+            w2,
+            group_sizes,
+            transpose_rhs=True,
+            mesh=mesh,
+            num_experts=global_num_experts,
+            ep_size=ep_size,
+        )
+    else:
+        x = jax.lax.with_sharding_constraint(
+            x, NamedSharding(mesh, P(None, "model")))
         x = tensor_sharded_gmm_row_parallel(
             x,
             w2,
             w2_bias,
             group_sizes,
+            transpose_rhs=True,
             mesh=mesh,
         )

-    def _finalize_output(x_local, topk_argsort_revert_indices_local,
-                         topk_weights_local):
-        x_local = x_local[topk_argsort_revert_indices_local].reshape(
-            -1, topk, hidden_size)
-        x_local = x_local * jnp.expand_dims(topk_weights_local, axis=-1)
-        x_local = x_local.sum(axis=-2)
-        return x_local
+    x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
+    x = x * jnp.expand_dims(topk_weights, axis=-1)
+    x = x.sum(axis=-2)
+    x = x.reshape(orig_shape)

-    x = shard_map(
-        _finalize_output,
-        mesh=mesh,
-        in_specs=(P("data", None), P("data"), P("data", None)),
-        out_specs=(P("data", None)),
-    )(x, topk_argsort_revert_indices, topk_weights)
+    if reduce_results:
+        x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
+    return x

-    return x[:num_tokens, :hidden_size]
+
+@functools.partial(
+    jax.jit,
+    static_argnames=(
+        "topk",
+        "global_num_experts",
+        "renormalize",
+        "reduce_results",
+        "mesh",
+        "use_ep",
+        "activation",
+    ),
+)
+def fused_moe_func_padded(
+    hidden_states: jax.Array,
+    w1: jax.Array,
+    w2: jax.Array,
+    w1_bias: jax.Array | None,
+    w2_bias: jax.Array | None,
+    gating_output: jax.Array,
+    topk: int,
+    global_num_experts: int,
+    renormalize: bool,
+    reduce_results: bool,
+    mesh: Mesh,
+    use_ep: bool,
+    activation: str,
+):
+    # TODO(fanhongmin@google.com): Once the jax runner pads the input, we no longer need this.
+    hidden_size = hidden_states.shape[-1]
+    num_tokens = hidden_states.size // hidden_size
+    if num_tokens * topk < 16:
+        assert 16 % (num_tokens *
+                     topk) == 0, f"Cannot pad to 16: {num_tokens=}, {topk=}"
+        n_repeats = 16 // (num_tokens * topk)
+
+        reps = (n_repeats, ) + (1, ) * (hidden_states.ndim - 1)
+        expanded_hidden_states = jnp.tile(hidden_states, reps)
+
+        reps = (n_repeats, ) + (1, ) * (gating_output.ndim - 1)
+        expanded_gating_output = jnp.tile(gating_output, reps)
+
+        expanded_x = fused_moe_func(
+            expanded_hidden_states,
+            w1,
+            w2,
+            w1_bias,
+            w2_bias,
+            expanded_gating_output,
+            topk,
+            global_num_experts,
+            renormalize,
+            reduce_results,
+            mesh,
+            use_ep,
+            activation,
+        )
+        x = expanded_x[:hidden_states.shape[0]]
+        return x
+    else:
+        return fused_moe_func(
+            hidden_states,
+            w1,
+            w2,
+            w1_bias,
+            w2_bias,
+            gating_output,
+            topk,
+            global_num_experts,
+            renormalize,
+            reduce_results,
+            mesh,
+            use_ep,
+            activation,
+        )
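
The rewritten fused_moe_func above routes tokens by flattening the top-k expert ids, sorting the token copies by expert, and later inverting the sort with a second argsort. A compact sketch of that sort/unsort round trip on made-up routing data:

import jax.numpy as jnp

num_tokens, topk, num_experts = 4, 2, 3
# Expert ids chosen per token (e.g. from jax.lax.top_k on the router logits).
topk_indices = jnp.array([[0, 2], [1, 2], [0, 1], [2, 1]])

topk_indices_flat = topk_indices.flatten()             # 8 token copies
topk_argsort_indices = jnp.argsort(topk_indices_flat)  # order copies by expert id
topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)

token_indices = jnp.arange(num_tokens, dtype=jnp.int32).repeat(topk)
token_indices_sorted = token_indices[topk_argsort_indices]

# Contiguous per-expert row counts consumed by the grouped matmul.
group_sizes = jnp.bincount(topk_indices_flat, length=num_experts)  # [2 3 3]

# After the expert computation, rows are gathered back into token order:
rows = token_indices_sorted  # stand-in for the grouped-matmul output rows
assert (rows[topk_argsort_revert_indices] == token_indices).all()
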
@@ -5,12 +5,10 @@ from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization.base_config import \
     QuantizationConfig

-from tpu_inference.layers.common import quant_methods
 from tpu_inference.layers.vllm.quantization.awq import VllmAWQConfig
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors import \
     VllmCompressedTensorsConfig  # noqa: E501
-from tpu_inference.layers.vllm.quantization.mxfp4 import VllmMxfp4Config
 from tpu_inference.layers.vllm.quantization.unquantized import \
     VllmUnquantizedConfig

@@ -21,9 +19,8 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
     # TODO(kyuyeunk): Add support for "tpu_int8".
     method_to_config: dict[str, str] = {
         None: VllmUnquantizedConfig,
-        quant_methods.COMPRESSED_TENSORS: VllmCompressedTensorsConfig,
-        quant_methods.AWQ: VllmAWQConfig,
-        quant_methods.MXFP4: VllmMxfp4Config,
+        "compressed-tensors": VllmCompressedTensorsConfig,
+        "awq": VllmAWQConfig,
     }
     if model_config.quantization not in method_to_config:
         raise NotImplementedError(
@@ -33,7 +30,6 @@ def get_tpu_quantization_config(vllm_config: VllmConfig,
     assert issubclass(quant_config, JaxCommonConfig)
     quant_config.set_configs(vllm_config, mesh)

-    model_config.quantization = quant_methods.get_tpu_quant_method(
-        quant_config.get_name())
+    model_config.quantization = quant_config.get_name()
     return VllmConfig.get_quantization_config(model_config,
                                               vllm_config.load_config)