tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (76)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/kernels/mla_v1_test.py +129 -41
  3. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  4. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  5. tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  6. tests/lora/test_layers.py +4 -7
  7. tests/lora/test_lora_perf.py +53 -0
  8. tests/lora/utils.py +0 -8
  9. tests/test_envs.py +110 -12
  10. tests/test_quantization.py +3 -0
  11. tests/test_utils.py +1 -2
  12. tpu_inference/__init__.py +22 -3
  13. tpu_inference/core/disagg_utils.py +6 -8
  14. tpu_inference/distributed/tpu_connector.py +3 -4
  15. tpu_inference/distributed/utils.py +3 -2
  16. tpu_inference/envs.py +93 -9
  17. tpu_inference/executors/ray_distributed_executor.py +9 -2
  18. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  19. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  20. tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
  21. tpu_inference/kernels/mla/v1/kernel.py +98 -120
  22. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  23. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  24. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  25. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
  26. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
  27. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
  28. tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  29. tpu_inference/layers/common/attention_interface.py +7 -1
  30. tpu_inference/layers/common/sharding.py +11 -7
  31. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
  32. tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  33. tpu_inference/layers/vllm/fused_moe.py +170 -208
  34. tpu_inference/layers/vllm/linear_common.py +43 -21
  35. tpu_inference/layers/vllm/quantization/common.py +11 -6
  36. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
  38. tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
  39. tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
  40. tpu_inference/layers/vllm/sharding.py +2 -2
  41. tpu_inference/lora/torch_punica_tpu.py +1 -2
  42. tpu_inference/models/common/model_loader.py +84 -28
  43. tpu_inference/models/jax/deepseek_v3.py +185 -64
  44. tpu_inference/models/jax/gpt_oss.py +3 -3
  45. tpu_inference/models/jax/llama3.py +2 -1
  46. tpu_inference/models/jax/llama_eagle3.py +8 -5
  47. tpu_inference/models/jax/llama_guard_4.py +361 -0
  48. tpu_inference/models/jax/qwen2.py +2 -1
  49. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  50. tpu_inference/models/jax/qwen3.py +2 -1
  51. tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
  52. tpu_inference/models/jax/utils/weight_utils.py +205 -144
  53. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
  54. tpu_inference/platforms/tpu_platform.py +34 -50
  55. tpu_inference/runner/compilation_manager.py +144 -60
  56. tpu_inference/runner/kv_cache.py +40 -20
  57. tpu_inference/runner/kv_cache_manager.py +48 -33
  58. tpu_inference/runner/persistent_batch_manager.py +40 -2
  59. tpu_inference/runner/structured_decoding_manager.py +2 -3
  60. tpu_inference/runner/tpu_runner.py +280 -149
  61. tpu_inference/runner/utils.py +2 -2
  62. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  63. tpu_inference/tpu_info.py +4 -3
  64. tpu_inference/utils.py +46 -18
  65. tpu_inference/worker/tpu_worker.py +197 -63
  66. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
  67. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
  68. tpu_inference/mock/__init__.py +0 -0
  69. tpu_inference/mock/vllm_config_utils.py +0 -28
  70. tpu_inference/mock/vllm_envs.py +0 -1219
  71. tpu_inference/mock/vllm_logger.py +0 -212
  72. tpu_inference/mock/vllm_logging_utils.py +0 -15
  73. tpu_inference/models/jax/phi3.py +0 -376
  74. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
  75. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
  76. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
@@ -2,17 +2,16 @@ import functools
 
 import jax
 from jax import numpy as jnp
+from jax import shard_map
 from jax.experimental.pallas.ops.tpu.megablox.gmm import gmm
-from jax.experimental.shard_map import shard_map
-from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from jax.sharding import Mesh
+from jax.sharding import PartitionSpec as P
 
 from tpu_inference.layers.vllm.linear_common import \
     slice_sharded_tensor_for_concatenation
 
-P = PartitionSpec
 
-
-def activation_fn(activation: str, x1, x2):
+def activation_fn(activation: str, x1: jax.Array, x2: jax.Array) -> jax.Array:
     match activation:
         case "silu":
             return jax.nn.silu(x1) * x2
@@ -23,7 +22,10 @@ def activation_fn(activation: str, x1, x2):
                 f"FusedMoE does not support {activation} activation")
 
 
-def _swigluoai(x1, x2, alpha=1.702, limit=7.0):
+def _swigluoai(x1: jax.Array,
+               x2: jax.Array,
+               alpha=1.702,
+               limit=7.0) -> jax.Array:
     x1 = jnp.clip(x1, a_max=limit)
     x2 = jnp.clip(x2, a_min=-limit, a_max=limit)
 
@@ -103,40 +105,53 @@ def tensor_sharded_gmm_merged_column_parallel(
     rhs: jax.Array,
     rhs_bias: jax.Array | None,
     group_sizes: jax.Array,
-    transpose_rhs: bool,
     mesh: Mesh,
-    intermediate_size: int,
-) -> jax.Array:
-    # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
-    m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
-    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
-
-    _gmm = functools.partial(
-        gmm,
-        preferred_element_type=lhs.dtype,
-        tiling=(tm, tk, tn),
-        transpose_rhs=transpose_rhs,
-        group_offset=jnp.array(0),
-    )
+) -> tuple[jax.Array, jax.Array]:
+
+    def _gmm(lhs, rhs, group_sizes):
+        m, g, n, k = lhs.shape[0], *rhs.shape
+        tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+        return gmm(
+            lhs,
+            rhs,
+            group_sizes,
+            preferred_element_type=lhs.dtype,
+            tiling=(tm, tk, tn),
+            transpose_rhs=True,
+            group_offset=jnp.array(0),
+        )
 
     gmm_result = shard_map(
         _gmm,
         mesh=mesh,
-        in_specs=(P(), P(None, "model", None), P()),
-        out_specs=(P(None, "model")),
-        check_rep=False,
+        in_specs=(P("data", None), P(None, "model", None), P("data")),
+        out_specs=(P("data", "model")),
+        check_vma=False,
     )(lhs, rhs, group_sizes)
 
     if rhs_bias is not None:
-        rhs_bis = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
-        gmm_result = (gmm_result + rhs_bis).astype(gmm_result.dtype)
 
-    n_shards = mesh.shape["model"]
-    output_sizes = [intermediate_size, intermediate_size]
+        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
+            rhs_bias = jnp.repeat(
+                rhs_bias_local,
+                group_sizes_global,
+                0,
+                total_repeat_length=gmm_result_local.shape[0])
+            return gmm_result_local + rhs_bias
+
+        gmm_result = shard_map(
+            _add_bias,
+            mesh=mesh,
+            in_specs=(P("data", "model"), P(None, "model"), P("data")),
+            out_specs=(P("data", "model")),
+        )(gmm_result, rhs_bias, group_sizes)
+        gmm_result = gmm_result.astype(lhs.dtype)
 
+    tp_size = mesh.shape["model"]
+    intermediate_size = gmm_result.shape[-1] // 2
+    output_sizes = [intermediate_size, intermediate_size]
     return slice_sharded_tensor_for_concatenation(gmm_result, output_sizes,
-                                                  n_shards)
+                                                  tp_size)
 
 
 def tensor_sharded_gmm_row_parallel(
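
Note: the substance of these fused_moe hunks is that the GMM wrappers now shard tokens along a "data" mesh axis and use the newer jax.shard_map API (check_vma instead of check_rep). As a reminder of how shard_map's in_specs/out_specs map a global array onto per-shard blocks, here is a minimal self-contained sketch; it is not package code and assumes 8 devices (e.g. CPU with XLA_FLAGS=--xla_force_host_platform_device_count=8):

    import jax
    import jax.numpy as jnp
    from jax import shard_map
    from jax.sharding import PartitionSpec as P

    # Build a 2x4 mesh with the same axis names the kernels above use.
    mesh = jax.make_mesh((2, 4), ("data", "model"))

    def local_fn(x_block):
        # With in_specs=P("data", "model"), each program sees an
        # [8/2, 16/4] = [4, 4] block of the global [8, 16] array.
        return x_block * 2

    x = jnp.ones((8, 16))
    y = shard_map(local_fn, mesh=mesh, in_specs=(P("data", "model"),),
                  out_specs=P("data", "model"))(x)
    assert y.shape == (8, 16)
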
@@ -144,74 +159,75 @@ def tensor_sharded_gmm_row_parallel(
     rhs: jax.Array,
     rhs_bias: jax.Array | None,
     group_sizes: jax.Array,
-    transpose_rhs: bool,
     mesh: Mesh,
 ) -> jax.Array:
-    # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
-    m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
-    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
-
-    _gmm = functools.partial(
-        gmm,
-        preferred_element_type=lhs.dtype,
-        tiling=(tm, tk, tn),
-        transpose_rhs=transpose_rhs,
-        group_offset=jnp.array(0),
-    )
 
     def _gmm_all_reduce(lhs, rhs, group_sizes):
-        r = _gmm(lhs, rhs, group_sizes)
-        return jax.lax.psum(r, axis_name="model")
+        m, g, n, k = lhs.shape[0], *rhs.shape
+        tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+        out = gmm(
+            lhs,
+            rhs,
+            group_sizes,
+            preferred_element_type=lhs.dtype,
+            tiling=(tm, tk, tn),
+            transpose_rhs=True,
+            group_offset=jnp.array(0),
+        )
+        return jax.lax.psum(out, axis_name="model")
 
     gmm_result = shard_map(
         _gmm_all_reduce,
         mesh=mesh,
-        in_specs=(P(None, "model"), P(None, None, "model"), P()),
-        out_specs=(P()),
-        check_rep=False,
+        in_specs=(P("data", "model"), P(None, None, "model"), P("data")),
+        out_specs=(P("data")),
+        check_vma=False,
     )(lhs, rhs, group_sizes)
 
     if rhs_bias is not None:
-        rhs_bias = jnp.repeat(rhs_bias, group_sizes, 0, total_repeat_length=m)
-        gmm_result = (gmm_result + rhs_bias).astype(gmm_result.dtype)
 
-    return gmm_result
+        def _add_bias(gmm_result_local, rhs_bias_local, group_sizes_global):
+            rhs_bias = jnp.repeat(
+                rhs_bias_local,
+                group_sizes_global,
+                0,
+                total_repeat_length=gmm_result_local.shape[0])
+            return gmm_result_local + rhs_bias
+
+        gmm_result = shard_map(
+            _add_bias,
+            mesh=mesh,
+            in_specs=(P("data"), P(), P("data")),
+            out_specs=(P("data")),
+        )(gmm_result, rhs_bias, group_sizes)
+
+    return gmm_result.astype(lhs.dtype)
 
 
 def expert_sharded_gmm(
     lhs: jax.Array,
     rhs: jax.Array,
     group_sizes: jax.Array,
-    transpose_rhs: bool,
     mesh: Mesh,
-    num_experts: int,
-    ep_size: int,
 ) -> jax.Array:
-    # adapted from https://github.com/pytorch/xla/blob/1d409399474197c484894be90b75d9855393dda5/torch_xla/experimental/custom_kernel.py#L1401
-    m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
-    n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
-    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+    ep_size = mesh.shape["model"]
 
+    num_experts = rhs.shape[0]
     num_experts_per_shard = num_experts // ep_size
     group_offset = jnp.arange(0, num_experts, num_experts_per_shard)
-    group_offset = jax.lax.with_sharding_constraint(
-        group_offset, NamedSharding(mesh, P("model")))
 
     def _gmm(lhs, rhs, group_sizes, group_offset):
-        # Group offset for this shard. `group_offset` is sharded, and in this
-        # sharded function, it has only 1 element and `group_offset.shape` is
-        # (1,) but gmm kernel requires the group_offset to be a ()-shaped array,
-        # so we group_offset[0].
-        group_offset_of_shard = group_offset[0]
+        m, g, n, k = lhs.shape[0], *rhs.shape
+        tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+
         gmm_res = gmm(
             lhs=lhs,
             rhs=rhs,
             group_sizes=group_sizes,
             preferred_element_type=lhs.dtype,
             tiling=(tm, tk, tn),
-            transpose_rhs=transpose_rhs,
-            group_offset=group_offset_of_shard,
+            transpose_rhs=True,
+            group_offset=group_offset[0],
         )
         return gmm_res
 
@@ -238,30 +254,24 @@ def expert_sharded_gmm(
         mesh=mesh,
         in_specs=(P(), P("model", None, None), P(), P("model")),
         out_specs=(P("model", None)),
-        check_rep=False,
+        check_vma=False,
     )(lhs, rhs, group_sizes, group_offset)
 
     # For i-th shard, it is responsible groups (AKA experts) from
     # i*num_experts_per_shard to (i+1)*num_experts_per_shard We sum them up to
     # get total rows in that shard, and that is the size for shard to send to
     # its peers. This is also the number of non-zero rows from the gmm results.
-    # In the working example, send_sizes would be [3, 2, 5, 4]
-    send_sizes = jnp.array([
-        group_sizes[i * num_experts_per_shard:(i + 1) *
-                    num_experts_per_shard].sum() for i in range(ep_size)
-    ])
+    # In the working example, send_sizes would be [3, 2, 5, 4].
+
+    # group_sizes has shape of [num_tokens_per_shard * num_experts_per_shard].
+    # So reshaping to [num_tokens_per_shard, num_experts_per_shard] and applying
+    # sum(axis=1) will get desired send_sizes shaped [num_tokens_per_shard].
+    send_sizes = group_sizes.reshape(-1, num_experts_per_shard).sum(axis=1)
     # In the working example, input_offsets would be [0, 3, 5, 10]
     input_offsets = jnp.concatenate((jnp.array([0]), send_sizes.cumsum()[:-1]))
     output_offsets = input_offsets
     recv_sizes = send_sizes
 
-    input_offsets = jax.lax.with_sharding_constraint(
-        input_offsets, NamedSharding(mesh, P("model")))
-    send_sizes = jax.lax.with_sharding_constraint(
-        send_sizes, NamedSharding(mesh, P("model")))
-    output_offsets = jax.lax.with_sharding_constraint(
-        output_offsets, NamedSharding(mesh, P("model")))
-
     def _ragged_all_to_all(operand, input_offsets, send_sizes, output_offsets,
                            recv_sizes):
         output = jnp.zeros_like(operand)
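
The offset bookkeeping in this hunk can be checked in isolation. In the sketch below the group_sizes values are invented so that the result reproduces the working example quoted in the comments (send_sizes [3, 2, 5, 4], input_offsets [0, 3, 5, 10]):

    import jax.numpy as jnp

    num_experts_per_shard = 2  # 8 experts spread over 4 shards
    group_sizes = jnp.array([1, 2, 2, 0, 1, 4, 3, 1])  # rows per expert (invented)

    # Rows each shard contributes: sum over its num_experts_per_shard experts.
    send_sizes = group_sizes.reshape(-1, num_experts_per_shard).sum(axis=1)
    # send_sizes == [3, 2, 5, 4]
    input_offsets = jnp.concatenate((jnp.array([0]), send_sizes.cumsum()[:-1]))
    # input_offsets == [0, 3, 5, 10]
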
@@ -316,10 +326,20 @@ def expert_sharded_gmm(
         mesh=mesh,
         in_specs=(P("model", None), P("model"), P("model"), P("model"), P()),
         out_specs=(P()),
-        check_rep=False,
+        check_vma=False,
     )(gmm_res, input_offsets, send_sizes, output_offsets, recv_sizes)
 
 
+@functools.partial(
+    jax.jit,
+    static_argnames=(
+        "topk",
+        "renormalize",
+        "mesh",
+        "use_ep",
+        "activation",
+    ),
+)
 def fused_moe_func(
     hidden_states: jax.Array,
     w1: jax.Array,
@@ -328,37 +348,45 @@ def fused_moe_func(
     w2_bias: jax.Array | None,
     gating_output: jax.Array,
     topk: int,
-    global_num_experts: int,
     renormalize: bool,
-    reduce_results: bool,
     mesh: Mesh,
     use_ep: bool,
     activation: str,
-):
+) -> jax.Array:
     """
+    Route tokens in hidden_states into each experts based on routing
+    information in gating_out and performs moe with w1 and w2 weights.
+
     Args:
-        hidden_states: [*, hidden_size]
-        w1: [num_experts, intermediate_size * 2, hidden_size]
-        w2: [num_experts, hidden_size, intermediate_size]
-        gating_output: [*, num_experts]
+        hidden_states: [num_tokens, hidden_size]
+        w1: first moe weights [num_experts, intermediate_size * 2, hidden_size]
+        w2: second moe weights [num_experts, hidden_size, intermediate_size]
+        w1_bias: optional bias of w1 [num_experts, intermediate_size * 2]
+        w2_bias: optional bias of w2 [num_experts, hidden_size]
+        gating_output: routing information of tokens [num_tokens, num_experts]
+        topk: number of experts to choose per token.
+        renormalize: normalize gating_output.
+        mesh: mesh to perform moe.
+        use_ep: use expert parallelism.
+        activation: activation function to perform on the output of w1.
+
+    Returns:
+        Output of moe operation [num_tokens, hidden_size]
     """
-    # adapted from https://github.com/vllm-project/vllm/blob/29fa5cac1cd731026f59084d93a822921507573c/vllm/model_executor/layers/fused_moe/moe_pallas.py#L26
     if use_ep and (w1_bias is not None or w2_bias is not None):
         raise NotImplementedError(
             "Bias is not supported when using expert parallelism.")
-    orig_shape = hidden_states.shape
-    hidden_size = hidden_states.shape[-1]
-    num_tokens = hidden_states.size // hidden_size
-    assert global_num_experts == w1.shape[0]
-    ep_size = mesh.shape["model"]  # only used if use_ep is True.
-    intermediate_size = w2.shape[-1]
+
+    num_tokens = hidden_states.shape[0]
+    global_num_experts, hidden_size, intermediate_size = w2.shape
     dtype = hidden_states.dtype
+
     assert (num_tokens * topk) % 16 == 0, (
         "The kernel requires num_tokens * topk to be a multiple of "
         f"16 but got {num_tokens}*{topk}={num_tokens*topk}")
-
-    hidden_states = hidden_states.reshape(num_tokens, hidden_size)
-    gating_output = gating_output.reshape(num_tokens, global_num_experts)
+    assert hidden_states.shape == (num_tokens, hidden_size)
+    assert gating_output.shape == (num_tokens, global_num_experts)
+    assert w1.shape == (global_num_experts, intermediate_size * 2, hidden_size)
 
     topk_weights = jax.nn.softmax(gating_output.astype(jnp.float32), axis=-1)
     topk_weights, topk_indices = jax.lax.top_k(topk_weights, k=topk)
@@ -366,142 +394,76 @@ def fused_moe_func(
     topk_weights = topk_weights / topk_weights.sum(axis=-1, keepdims=True)
     topk_weights = topk_weights.astype(dtype)
 
-    topk_indices_flat = topk_indices.flatten()
-    topk_argsort_indices = jnp.argsort(topk_indices_flat)
-    topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
-    token_indices = jnp.arange(num_tokens, dtype=jnp.int32).repeat(topk)
-    token_indices_sorted = token_indices[topk_argsort_indices]
-    group_sizes = jnp.bincount(topk_indices_flat, length=global_num_experts)
-
-    x = hidden_states[token_indices_sorted]
+    def _process_tokens_locally(hidden_states_local, topk_indices_local):
+        num_tokens_local = hidden_states_local.shape[0]
+        topk_indices_flat = topk_indices_local.flatten()
+        topk_argsort_indices = jnp.argsort(topk_indices_flat)
+        topk_argsort_revert_indices = jnp.argsort(topk_argsort_indices)
+        token_indices = jnp.arange(num_tokens_local,
+                                   dtype=jnp.int32).repeat(topk)
+        token_indices_sorted = token_indices[topk_argsort_indices]
+        group_sizes_local = jnp.bincount(topk_indices_flat,
+                                         length=global_num_experts)
+
+        x = hidden_states_local[token_indices_sorted]
+        return x, group_sizes_local, topk_argsort_revert_indices
+
+    x, group_sizes, topk_argsort_revert_indices = shard_map(
+        _process_tokens_locally,
+        mesh=mesh,
+        in_specs=(P("data", None), P("data", None)),
+        out_specs=(P("data", None), P("data"), P("data")),
+    )(hidden_states, topk_indices)
 
     if use_ep:
         x = expert_sharded_gmm(
             x,
             w1,
             group_sizes,
-            transpose_rhs=True,
             mesh=mesh,
-            num_experts=global_num_experts,
-            ep_size=ep_size,
         )
-        x1, x2 = x[..., :intermediate_size], x[..., intermediate_size:]
+        x1, x2 = jnp.split(x, 2, -1)
+
+        x = activation_fn(activation, x1, x2)
+
+        x = expert_sharded_gmm(
+            x,
+            w2,
+            group_sizes,
+            mesh=mesh,
+        )
     else:
         x1, x2 = tensor_sharded_gmm_merged_column_parallel(
             x,
             w1,
             w1_bias,
             group_sizes,
-            transpose_rhs=True,
             mesh=mesh,
-            intermediate_size=intermediate_size,
         )
 
-    x = activation_fn(activation, x1, x2)
+        x = activation_fn(activation, x1, x2)
 
-    if use_ep:
-        x = expert_sharded_gmm(
-            x,
-            w2,
-            group_sizes,
-            transpose_rhs=True,
-            mesh=mesh,
-            num_experts=global_num_experts,
-            ep_size=ep_size,
-        )
-    else:
-        x = jax.lax.with_sharding_constraint(
-            x, NamedSharding(mesh, P(None, "model")))
         x = tensor_sharded_gmm_row_parallel(
             x,
             w2,
             w2_bias,
             group_sizes,
-            transpose_rhs=True,
             mesh=mesh,
         )
 
-    x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
-    x = x * jnp.expand_dims(topk_weights, axis=-1)
-    x = x.sum(axis=-2)
-    x = x.reshape(orig_shape)
-
-    if reduce_results:
-        x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
-    return x
+    def _finalize_output(x_local, topk_argsort_revert_indices_local,
+                         topk_weights_local):
+        x_local = x_local[topk_argsort_revert_indices_local].reshape(
+            -1, topk, hidden_size)
+        x_local = x_local * jnp.expand_dims(topk_weights_local, axis=-1)
+        x_local = x_local.sum(axis=-2)
+        return x_local
 
+    x = shard_map(
+        _finalize_output,
+        mesh=mesh,
+        in_specs=(P("data", None), P("data"), P("data", None)),
+        out_specs=(P("data", None)),
+    )(x, topk_argsort_revert_indices, topk_weights)
 
-@functools.partial(
-    jax.jit,
-    static_argnames=(
-        "topk",
-        "global_num_experts",
-        "renormalize",
-        "reduce_results",
-        "mesh",
-        "use_ep",
-        "activation",
-    ),
-)
-def fused_moe_func_padded(
-    hidden_states: jax.Array,
-    w1: jax.Array,
-    w2: jax.Array,
-    w1_bias: jax.Array | None,
-    w2_bias: jax.Array | None,
-    gating_output: jax.Array,
-    topk: int,
-    global_num_experts: int,
-    renormalize: bool,
-    reduce_results: bool,
-    mesh: Mesh,
-    use_ep: bool,
-    activation: str,
-):
-    # TODO(fanhongmin@google.com): Once the jax runner pads the input, we no longer need this.
-    hidden_size = hidden_states.shape[-1]
-    num_tokens = hidden_states.size // hidden_size
-    if num_tokens * topk < 16:
-        assert 16 % (num_tokens *
-                     topk) == 0, f"Cannot pad to 16: {num_tokens=}, {topk=}"
-        n_repeats = 16 // (num_tokens * topk)
-
-        reps = (n_repeats, ) + (1, ) * (hidden_states.ndim - 1)
-        expanded_hidden_states = jnp.tile(hidden_states, reps)
-
-        reps = (n_repeats, ) + (1, ) * (gating_output.ndim - 1)
-        expanded_gating_output = jnp.tile(gating_output, reps)
-
-        expanded_x = fused_moe_func(
-            expanded_hidden_states,
-            w1,
-            w2,
-            w1_bias,
-            w2_bias,
-            expanded_gating_output,
-            topk,
-            global_num_experts,
-            renormalize,
-            reduce_results,
-            mesh,
-            use_ep,
-            activation,
-        )
-        x = expanded_x[:hidden_states.shape[0]]
-        return x
-    else:
-        return fused_moe_func(
-            hidden_states,
-            w1,
-            w2,
-            w1_bias,
-            w2_bias,
-            gating_output,
-            topk,
-            global_num_experts,
-            renormalize,
-            reduce_results,
-            mesh,
-            use_ep,
-            activation,
-        )
+    return x[:num_tokens, :hidden_size]
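
Note: the routing bookkeeping that used to live at function scope is now done per data shard inside _process_tokens_locally and undone in _finalize_output. A standalone sketch of that bookkeeping, with illustrative values that are not taken from the package, looks like this:

    import jax.numpy as jnp

    num_tokens, topk, num_experts = 4, 2, 4
    topk_indices = jnp.array([[0, 2], [1, 2], [3, 0], [2, 1]])  # invented routing

    topk_indices_flat = topk_indices.flatten()
    order = jnp.argsort(topk_indices_flat)   # group rows by expert id
    revert = jnp.argsort(order)              # used later to restore token order
    token_ids = jnp.arange(num_tokens, dtype=jnp.int32).repeat(topk)
    sorted_token_ids = token_ids[order]      # which token feeds each gmm row
    group_sizes = jnp.bincount(topk_indices_flat, length=num_experts)
    # group_sizes == [2, 2, 3, 1]; after the gmm calls, y[revert] puts the rows
    # back into (token, topk) order before the weighted sum over experts.
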
@@ -9,30 +9,52 @@ from jax.sharding import PartitionSpec as P
 from torchax.interop import torch_view
 from torchax.ops.mappings import t2j
 
-from tpu_inference.kernels.quantized_matmul.kernel import \
-    quantized_matmul_kernel
+from tpu_inference import envs
+from tpu_inference.kernels.quantized_matmul.kernel import (
+    quantized_matmul_kernel, xla_quantized_matmul)
 
 
 def sharded_quantized_matmul(x: jax.Array, w_q: jax.Array, w_s: jax.Array,
-                             mesh: Mesh, weight_sharding: P):
-    out_axis, in_axis = weight_sharding
-    x_sharding = P(None, in_axis)
-    scale_sharding = P(out_axis, )
-    out_sharding = P(None, out_axis)
-
-    x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, x_sharding))
-
-    def wrapper(x, w_q, w_s):
-        output = quantized_matmul_kernel(x, w_q, w_s, x_q_dtype=w_q.dtype)
-        if in_axis:
-            output = jax.lax.psum(output, axis_name=in_axis)
-        return output
-
-    return shard_map(wrapper,
-                     mesh=mesh,
-                     in_specs=(x_sharding, weight_sharding, scale_sharding),
-                     out_specs=(out_sharding),
-                     check_rep=False)(x, w_q, w_s)
+                             mesh: Mesh, weight_sharding: P) -> jax.Array:
+    """
+    Wrapper around the quantized matmul kernel.
+
+    Args:
+        x: Activation.
+        w_q: Weight quantized array. [n_output_features, n_input_features]
+        w_s: Weight quantization scale. [n_output_features]
+        mesh: Mesh to shard on.
+        weight_sharding: PartitionSpec for the weight tensor.
+
+    Returns:
+        Output of the quantized matmul.
+    """
+
+    # NOTE (jacobplatin/kyuyeunk) there have been numeric issues (concerning) NaNs
+    # with the kernel and thus we disable it for now.
+    if envs.ENABLE_QUANTIZED_MATMUL_KERNEL:
+        out_axis, in_axis = weight_sharding
+        x_sharding = P(None, in_axis)
+        scale_sharding = P(out_axis, )
+        out_sharding = P(None, out_axis)
+
+        x = jax.lax.with_sharding_constraint(x,
+                                             NamedSharding(mesh, x_sharding))
+
+        def wrapper(x, w_q, w_s):
+            output = quantized_matmul_kernel(x, w_q, w_s, x_q_dtype=w_q.dtype)
+            if in_axis:
+                output = jax.lax.psum(output, axis_name=in_axis)
+            return output
+
+        return shard_map(wrapper,
+                         mesh=mesh,
+                         in_specs=(x_sharding, weight_sharding,
+                                   scale_sharding),
+                         out_specs=(out_sharding),
+                         check_rep=False)(x, w_q, w_s)
+    else:
+        return xla_quantized_matmul(x, w_q, w_s)
 
 
 def reorder_concatenated_tensor_for_sharding(concatenated_tensor: jax.Array,
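
Note: the Pallas kernel path is now gated on envs.ENABLE_QUANTIZED_MATMUL_KERNEL, with xla_quantized_matmul as the fallback. The package's actual fallback lives in tpu_inference/kernels/quantized_matmul/kernel.py; purely to illustrate the shape contract documented above (w_q: [n_output_features, n_input_features], w_s: [n_output_features]), a generic weight-only dequantizing matmul can be sketched as:

    import jax
    import jax.numpy as jnp

    def xla_quantized_matmul_sketch(x: jax.Array, w_q: jax.Array,
                                    w_s: jax.Array) -> jax.Array:
        # Hypothetical stand-in, not the package implementation.
        # x: [num_tokens, n_in], w_q: [n_out, n_in], w_s: [n_out]
        out = jnp.einsum("tk,ok->to", x, w_q.astype(x.dtype))
        return out * w_s.astype(x.dtype)
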
@@ -31,17 +31,17 @@ class JaxCommonLinearConfig:
         self.output_sizes = [layer.output_size]
         self.weight_sharding = P(None, None)
         self.fuse_matmuls = True
-        self.enable_sequence_parallelism = vllm_config.compilation_config.pass_config.enable_sequence_parallelism
+        self.enable_sp = vllm_config.compilation_config.pass_config.enable_sp
         self.input_sharding = None
         self.output_sharding = None
 
         if isinstance(layer, RowParallelLinear):
             self.weight_sharding = P(None, "model")
-            if self.enable_sequence_parallelism:
+            if self.enable_sp:
                 self.output_sharding = P("model", None)
         elif isinstance(layer, ColumnParallelLinear):
             self.weight_sharding = P("model", None)
-            if self.enable_sequence_parallelism:
+            if self.enable_sp:
                 self.input_sharding = P("model", None)
 
         if isinstance(layer, MergedColumnParallelLinear) or isinstance(
@@ -61,10 +61,15 @@ class JaxCommonLinearConfig:
             " bad performance.", type(layer))
 
         self.bias_sharding = P(self.weight_sharding[0])
-        self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
+        if isinstance(self.weight_sharding[0], tuple):
+            self.n_shards = 1
+            for axis in self.weight_sharding[0]:
+                self.n_shards *= self.mesh.shape.get(axis, 1)
+        else:
+            self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
 
     def get_input_sharding(self, x: torchax.tensor.Tensor):
-        if self.enable_sequence_parallelism:
+        if self.enable_sp:
             token_num = x.shape[0]
             # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
             if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
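
For example (values invented), with weight_sharding = P(("data", "model"), None) on a mesh of shape {"data": 2, "model": 4}, the tuple branch above multiplies the axis sizes:

    mesh_shape = {"data": 2, "model": 4}
    sharding_axis = ("data", "model")  # weight_sharding[0]

    n_shards = 1
    for axis in sharding_axis:
        n_shards *= mesh_shape.get(axis, 1)
    assert n_shards == 8  # a plain "model" axis would give 4
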
@@ -74,7 +79,7 @@ class JaxCommonLinearConfig:
             return self.input_sharding
 
     def get_output_sharding(self, x: torchax.tensor.Tensor):
-        if self.enable_sequence_parallelism:
+        if self.enable_sp:
             token_num = x.shape[0]
             # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
             if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
@@ -20,7 +20,7 @@ from tpu_inference.layers.common.quant_methods import (COMPRESSED_TENSORS,
                                                        get_tpu_quant_method)
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.compressed_tensors.compressed_tensors_moe import \
-    VllmCompressedTensorsW8A8Fp8MoEMethod
+    VllmCompressedTensorsMoEMethod
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_fp8 import \
     VllmCompressedTensorsW8A8Fp8
 from tpu_inference.layers.vllm.quantization.compressed_tensors.schemes.compressed_tensors_w8a8_int8 import \
@@ -113,8 +113,9 @@ class VllmCompressedTensorsConfig(CompressedTensorsConfig, JaxCommonConfig):
             layer.scheme = scheme
             return CompressedTensorsLinearMethod(self)
         if isinstance(layer, FusedMoE):
-            return VllmCompressedTensorsW8A8Fp8MoEMethod(
-                self, layer.quant_config, self.mesh)
+            layer.moe_config = self.get_moe_config(layer)
+            return VllmCompressedTensorsMoEMethod.get_moe_method(
+                self, layer, layer_name=prefix)
         if isinstance(layer, Attention):
             return CompressedTensorsKVCacheMethod(self)
         return None