tpu-inference 0.12.0.dev20251222__py3-none-any.whl → 0.12.0.dev20251224__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in the public registry. It is provided for informational purposes only.
Files changed (47)
  1. tests/core/test_dp_scheduler.py +128 -71
  2. tests/e2e/test_data_parallel.py +176 -280
  3. tests/e2e/test_hybrid_kvcache.py +219 -0
  4. tests/e2e/test_speculative_decoding.py +26 -6
  5. tests/layers/jax/test_qwix.py +1 -1
  6. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +36 -21
  7. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +36 -21
  8. tests/layers/vllm/test_mxfp4.py +25 -10
  9. tests/layers/vllm/test_unquantized.py +61 -31
  10. tests/layers/vllm/utils.py +19 -4
  11. tests/models/common/test_model_loader.py +2 -2
  12. tests/models/jax/test_qwen2_5_vl.py +10 -11
  13. tests/runner/test_multimodal_manager.py +3 -3
  14. tests/runner/test_tpu_runner.py +67 -8
  15. tests/runner/test_tpu_runner_dp.py +66 -0
  16. tpu_inference/core/sched/dp_scheduler.py +65 -40
  17. tpu_inference/kernels/mla/v1/kernel.py +7 -26
  18. tpu_inference/layers/common/sharding.py +8 -3
  19. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +3 -3
  20. tpu_inference/layers/jax/attention/gpt_oss_attention.py +3 -3
  21. tpu_inference/layers/jax/attention/llama4_attention.py +3 -4
  22. tpu_inference/layers/jax/sample/sampling.py +1 -1
  23. tpu_inference/layers/vllm/fused_moe.py +51 -47
  24. tpu_inference/layers/vllm/quantization/common.py +14 -13
  25. tpu_inference/layers/vllm/quantization/mxfp4.py +21 -7
  26. tpu_inference/layers/vllm/quantization/unquantized.py +19 -7
  27. tpu_inference/layers/vllm/sharding.py +7 -4
  28. tpu_inference/models/common/model_loader.py +11 -14
  29. tpu_inference/models/jax/llama3.py +13 -10
  30. tpu_inference/models/jax/llama_guard_4.py +1 -1
  31. tpu_inference/models/jax/qwen2.py +3 -2
  32. tpu_inference/models/jax/qwen2_5_vl.py +4 -4
  33. tpu_inference/models/jax/utils/multi_modal_utils.py +4 -4
  34. tpu_inference/models/jax/utils/qwix/qwix_utils.py +3 -3
  35. tpu_inference/models/vllm/vllm_model_wrapper.py +5 -2
  36. tpu_inference/platforms/tpu_platform.py +7 -7
  37. tpu_inference/runner/compilation_manager.py +43 -33
  38. tpu_inference/runner/kv_cache_manager.py +1 -2
  39. tpu_inference/runner/multimodal_manager.py +1 -1
  40. tpu_inference/runner/tpu_runner.py +12 -9
  41. tpu_inference/utils.py +31 -30
  42. tpu_inference/worker/tpu_worker.py +5 -2
  43. {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/METADATA +1 -1
  44. {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/RECORD +47 -46
  45. {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/WHEEL +0 -0
  46. {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/licenses/LICENSE +0 -0
  47. {tpu_inference-0.12.0.dev20251222.dist-info → tpu_inference-0.12.0.dev20251224.dist-info}/top_level.txt +0 -0
@@ -25,9 +25,10 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
 
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.vllm.linear_common import \
     get_model_matmul_fusion_assignment
-from tpu_inference.utils import TPU_SECOND_LAST_MINOR
+from tpu_inference.utils import TPU_SECOND_LAST_MINOR, get_mesh_shape_product
 
 # yapf: enable
 
@@ -49,14 +50,18 @@ class JaxCommonLinearConfig:
         self.input_sharding = None
         self.output_sharding = None
 
+        self.tp_size = get_mesh_shape_product(self.mesh,
+                                              ShardingAxisName.MLP_TENSOR)
+
         if isinstance(layer, RowParallelLinear):
-            self.weight_sharding = P(None, "model")
+            self.weight_sharding = P(None, ShardingAxisName.ATTN_HEAD)
             if self.enable_sp:
-                self.output_sharding = P("model", None)
+                self.output_sharding = P(ShardingAxisName.MLP_TENSOR, None)
         elif isinstance(layer, ColumnParallelLinear):
-            self.weight_sharding = P("model", None)
+            self.weight_sharding = P(ShardingAxisName.ATTN_HEAD, None)
+
             if self.enable_sp:
-                self.input_sharding = P("model", None)
+                self.input_sharding = P(ShardingAxisName.MLP_TENSOR, None)
 
         if isinstance(layer, MergedColumnParallelLinear) or isinstance(
                 layer, QKVParallelLinear):
@@ -75,18 +80,14 @@ class JaxCommonLinearConfig:
                            " bad performance.", type(layer))
 
         self.bias_sharding = P(self.weight_sharding[0])
-        if isinstance(self.weight_sharding[0], tuple):
-            self.n_shards = 1
-            for axis in self.weight_sharding[0]:
-                self.n_shards *= self.mesh.shape.get(axis, 1)
-        else:
-            self.n_shards = self.mesh.shape.get(self.weight_sharding[0], 1)
+        self.n_shards = get_mesh_shape_product(self.mesh,
+                                               self.weight_sharding[0])
 
     def get_input_sharding(self, x: torchax.tensor.Tensor):
         if self.enable_sp:
             token_num = x.shape[0]
             # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
-            if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
+            if token_num // self.tp_size >= TPU_SECOND_LAST_MINOR:
                 return self.input_sharding
             else:
                 return None
@@ -96,7 +97,7 @@ class JaxCommonLinearConfig:
         if self.enable_sp:
             token_num = x.shape[0]
             # NOTE(chengjiyao): make sure the sharded token_num is larger than TPU_SECOND_LAST_MINOR
-            if token_num // self.mesh.shape["model"] >= TPU_SECOND_LAST_MINOR:
+            if token_num // self.tp_size >= TPU_SECOND_LAST_MINOR:
                 return self.output_sharding
             else:
                 return None
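Many hunks in this release replace ad-hoc lookups of the "model" mesh axis with a shared get_mesh_shape_product helper. A minimal sketch of what such a helper plausibly does, reconstructed from the inline logic removed above (the actual implementation in tpu_inference/utils.py may differ):

from jax.sharding import Mesh


def get_mesh_shape_product(mesh: Mesh, axis_names) -> int:
    """Product of mesh axis sizes for a single axis name or a tuple of names.

    Axes missing from the mesh count as size 1, mirroring the inline
    mesh.shape.get(axis, 1) logic this helper replaces.
    """
    if not isinstance(axis_names, tuple):
        axis_names = (axis_names, )
    size = 1
    for axis in axis_names:
        size *= mesh.shape.get(axis, 1)
    return size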
@@ -44,12 +44,14 @@ from tpu_inference.layers.common.quant_methods import (MXFP4,
                                                         get_tpu_quant_method)
 from tpu_inference.layers.common.quantization import (
     dequantize_tensor_from_mxfp4_packed, quantize_tensor)
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.vllm.fused_moe import fused_moe_func
 from tpu_inference.layers.vllm.linear_common import \
     reorder_concatenated_tensor_for_sharding
 from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
 from tpu_inference.layers.vllm.quantization.unquantized import \
     VllmUnquantizedLinearMethod
+from tpu_inference.utils import get_mesh_shape_product
 
 REQUANTIZED_BLOCK_SIZE = 512
 
@@ -256,7 +258,8 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
             w2_bias = jnp.expand_dims(w2_bias, 1)
 
         if layer.use_ep:
-            ep_sharding = NamedSharding(self.mesh, P("model"))
+            ep_sharding = NamedSharding(self.mesh,
+                                        P(ShardingAxisName.EXPERT))
 
             w13_weight = jax.lax.with_sharding_constraint(
                 w13_weight, ep_sharding)
@@ -275,7 +278,8 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
 
         else:
             output_sizes = [intermediate_size, intermediate_size]
-            n_shards = self.mesh.shape["model"]
+            n_shards = get_mesh_shape_product(
+                self.mesh, ShardingAxisName.MLP_TENSOR)
             assert intermediate_size % n_shards == 0
 
             # Reorder w13 weights so that splitting w1 and w3 output
@@ -301,19 +305,29 @@ class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
 
             w13_weight = jax.lax.with_sharding_constraint(
                 w13_weight,
-                NamedSharding(self.mesh, P(None, "model", None)))
+                NamedSharding(
+                    self.mesh,
+                    P(None, ShardingAxisName.MLP_TENSOR, None)))
             w2_weight = jax.lax.with_sharding_constraint(
                 w2_weight,
-                NamedSharding(self.mesh, P(None, None, "model")))
+                NamedSharding(
+                    self.mesh,
+                    P(None, None, ShardingAxisName.MLP_TENSOR)))
             w13_weight_scale = jax.lax.with_sharding_constraint(
                 w13_weight_scale,
-                NamedSharding(self.mesh, P(None, None, None, "model")))
+                NamedSharding(
+                    self.mesh,
+                    P(None, None, None, ShardingAxisName.MLP_TENSOR)))
             w2_weight_scale = jax.lax.with_sharding_constraint(
                 w2_weight_scale,
-                NamedSharding(self.mesh, P(None, "model", None, None)))
+                NamedSharding(
+                    self.mesh,
+                    P(None, ShardingAxisName.MLP_TENSOR, None, None)))
             w13_bias = jax.lax.with_sharding_constraint(
                 w13_bias,
-                NamedSharding(self.mesh, P(None, None, "model")))
+                NamedSharding(
+                    self.mesh,
+                    P(None, None, ShardingAxisName.MLP_TENSOR)))
             w2_bias = jax.lax.with_sharding_constraint(
                 w2_bias, NamedSharding(self.mesh, P(None, None, None)))
 
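The MXFP4 hunks above swap the hard-coded "model" axis for symbolic ShardingAxisName constants in jax.lax.with_sharding_constraint calls. A self-contained sketch of that pattern on a toy one-dimensional mesh; the axis name below is a placeholder, not the real ShardingAxisName value:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

MLP_TENSOR = "model"  # stand-in for ShardingAxisName.MLP_TENSOR

mesh = Mesh(np.array(jax.devices()), axis_names=(MLP_TENSOR, ))


@jax.jit
def constrain_w13(w13_weight):
    # Shard dim 1 (the fused w1/w3 output dim) across the tensor-parallel axis.
    return jax.lax.with_sharding_constraint(
        w13_weight, NamedSharding(mesh, P(None, MLP_TENSOR, None)))


w13 = constrain_w13(jnp.zeros((8, 128, 64), jnp.bfloat16))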
@@ -39,12 +39,14 @@ from tpu_inference import envs
 from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
 from tpu_inference.layers.common.quant_methods import (UNQUANTIZED,
                                                         get_tpu_quant_method)
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.vllm.fused_moe import fused_moe_func
 from tpu_inference.layers.vllm.linear_common import (
     reorder_concatenated_tensor_for_sharding,
     slice_sharded_tensor_for_concatenation, torch_to_jax_param)
 from tpu_inference.layers.vllm.quantization.common import (
     JaxCommonConfig, JaxCommonLinearConfig)
+from tpu_inference.utils import get_mesh_shape_product
 
 P = PartitionSpec
 logger = init_logger(__name__)
@@ -307,7 +309,8 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
             w2_bias = jnp.expand_dims(w2_bias, 1)
 
         if layer.use_ep:
-            ep_sharding = NamedSharding(self.mesh, P("model"))
+            ep_sharding = NamedSharding(self.mesh,
+                                        P(ShardingAxisName.EXPERT))
             w13_weight = jax.device_put(
                 w13_weight, Format(Layout((0, 1, 2)), ep_sharding))
             w2_weight = jax.device_put(
@@ -321,19 +324,26 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
 
         else:
             output_sizes = [intermediate_size, intermediate_size]
-            n_shards = self.mesh.shape["model"]
+            n_shards = get_mesh_shape_product(self.mesh,
+                                              ShardingAxisName.MLP_TENSOR)
             assert intermediate_size % n_shards == 0
 
             w13_weight = reorder_concatenated_tensor_for_sharding(
                 w13_weight, output_sizes, n_shards, dim=1)
             w13_weight = jax.device_put(
                 w13_weight,
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P(None, "model", None))))
+                Format(
+                    Layout((0, 1, 2)),
+                    NamedSharding(
+                        self.mesh,
+                        P(None, ShardingAxisName.MLP_TENSOR, None))))
             w2_weight = jax.device_put(
                 w2_weight,
-                Format(Layout((0, 1, 2)),
-                       NamedSharding(self.mesh, P(None, None, "model"))))
+                Format(
+                    Layout((0, 1, 2)),
+                    NamedSharding(
+                        self.mesh,
+                        P(None, None, ShardingAxisName.MLP_TENSOR))))
 
             if self.moe.has_bias:
                 w13_bias = reorder_concatenated_tensor_for_sharding(
@@ -343,7 +353,9 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
                     w13_bias,
                     Format(
                         Layout((0, 1, 2)),
-                        NamedSharding(self.mesh, P(None, None, "model"))))
+                        NamedSharding(
+                            self.mesh,
+                            P(None, None, ShardingAxisName.MLP_TENSOR))))
                 w2_bias = jax.device_put(
                     w2_bias,
                     Format(Layout((0, 1, 2)),
@@ -34,6 +34,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 
 from tpu_inference import envs
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 
 P = PartitionSpec
@@ -123,7 +124,8 @@ def _shard_tensor_to_tpu_replicated(tensor: torch.Tensor,
 def _shard_vocab_parallel_embedding(layer: VocabParallelEmbedding,
                                     mesh: Mesh) -> None:
     weight = _convert_to_torchax_and_shard(
-        layer.weight, NamedSharding(mesh, P('model', None)))
+        layer.weight, NamedSharding(mesh, P(ShardingAxisName.MLP_TENSOR,
+                                            None)))
     layer.weight = Parameter(weight, requires_grad=False)
 
 
@@ -132,11 +134,12 @@ def _shard_lm_head(layer: ParallelLMHead, mesh: Mesh):
     # if that config is set, then we should not create new weights but reuse the
     # weight from VocabParallelEmbedding
     weight = _convert_to_torchax_and_shard(
-        layer.weight, NamedSharding(mesh, P('model', None)))
+        layer.weight, NamedSharding(mesh, P(ShardingAxisName.MLP_TENSOR,
+                                            None)))
     layer.weight = Parameter(weight, requires_grad=False)
     if layer.bias is not None:
-        bias = _convert_to_torchax_and_shard(layer.bias,
-                                             NamedSharding(mesh, P('model')))
+        bias = _convert_to_torchax_and_shard(
+            layer.bias, NamedSharding(mesh, P(ShardingAxisName.MLP_TENSOR)))
         layer.bias = Parameter(bias, requires_grad=False)
 
 
@@ -283,10 +283,9 @@ def get_flax_model(
 
     # Multi-modal support only
    # This function calculates the image token's embeddings by VIT
-    def run_get_multimodal_embeddings(graphdef, state, image_grid_thw,
-                                      **kwargs):
+    def run_embed_multimodal(graphdef, state, image_grid_thw, **kwargs):
         model = nnx.merge(graphdef, state)
-        return model.get_multimodal_embeddings(image_grid_thw, **kwargs)
+        return model.embed_multimodal(image_grid_thw, **kwargs)
 
     embed_sharding = NamedSharding(mesh, PartitionSpec(None))
     # This function will calculates the embeddings of input texts and then merge with the image embeddings
@@ -294,9 +293,9 @@ def get_flax_model(
         jax.jit,
         out_shardings=(embed_sharding),
     )
-    def run_get_input_embeddings(graphdef, state, *args, **kwargs):
+    def run_embed_input_ids(graphdef, state, *args, **kwargs):
         model = nnx.merge(graphdef, state)
-        return model.get_input_embeddings(*args, **kwargs)
+        return model.embed_input_ids(*args, **kwargs)
 
     # For models that want to work with EAGLE-3 speculative decoding
     @functools.partial(
@@ -312,10 +311,8 @@ def get_flax_model(
                                             None)
     model_fn = functools.partial(run_model, graphdef)
     compute_logits_fn = functools.partial(run_compute_logits, graphdef)
-    get_multimodal_embeddings_fn = functools.partial(
-        run_get_multimodal_embeddings, graphdef)
-    get_input_embeddings_fn = functools.partial(run_get_input_embeddings,
-                                                graphdef)
+    embed_multimodal_fn = functools.partial(run_embed_multimodal, graphdef)
+    embed_input_ids_fn = functools.partial(run_embed_input_ids, graphdef)
     lora_manager, model = None, None
     combine_hidden_states_fn = functools.partial(combine_hidden_states,
                                                  graphdef)
@@ -326,8 +323,8 @@ def get_flax_model(
 
     multimodal_fns = {
         "precompile_vision_encoder_fn": precompile_vision_encoder_fn,
-        "get_multimodal_embeddings_fn": get_multimodal_embeddings_fn,
-        "get_input_embeddings_fn": get_input_embeddings_fn,
+        "embed_multimodal_fn": embed_multimodal_fn,
+        "embed_input_ids_fn": embed_input_ids_fn,
         "get_mrope_input_positions_fn": get_mrope_input_positions_fn,
     }
 
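The run_embed_multimodal / run_embed_input_ids wrappers above follow the standard flax.nnx split/merge pattern: the module is split into a static graphdef plus an array-only state, the jitted wrapper re-merges them, and graphdef is pre-bound with functools.partial so callers pass only state and inputs. A minimal standalone sketch of that pattern with a toy module (the module and shapes are illustrative, not part of tpu-inference):

import functools

import jax
import jax.numpy as jnp
from flax import nnx


class ToyModel(nnx.Module):

    def __init__(self, rngs: nnx.Rngs):
        self.embed = nnx.Embed(num_embeddings=128, features=16, rngs=rngs)

    def embed_input_ids(self, input_ids):
        return self.embed(input_ids)


model = ToyModel(nnx.Rngs(0))
graphdef, state = nnx.split(model)


@jax.jit
def run_embed_input_ids(graphdef, state, input_ids):
    model = nnx.merge(graphdef, state)
    return model.embed_input_ids(input_ids)


# Bind the static graphdef once; callers only supply state plus inputs.
embed_input_ids_fn = functools.partial(run_embed_input_ids, graphdef)
embeddings = embed_input_ids_fn(state, jnp.array([1, 2, 3]))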
@@ -485,14 +482,14 @@ def register_model(arch: str, model: Any) -> None:
         )
 
     # Same as `forward`, this is a dummy method to satisfy vLLM's type checks.
-    def unimplemented_get_input_embeddings(
+    def unimplemented_embed_input_ids(
         self,
         input_ids: "torch.Tensor",
         positions: "torch.Tensor",
         inputs_embeds: Optional["torch.Tensor"] = None,
     ) -> "torch.Tensor":
         raise NotImplementedError(
-            "This is a JAX model and does not implement the PyTorch get_input_embeddings method."
+            "This is a JAX model and does not implement the PyTorch embed_input_ids method."
         )
 
     # We need a custom __init__ that only calls torch.nn.Module's init,
@@ -508,7 +505,7 @@ def register_model(arch: str, model: Any) -> None:
         {
             "__init__": wrapper_init,
             "forward": unimplemented_forward,
-            "get_input_embeddings": unimplemented_get_input_embeddings,
+            "embed_input_ids": unimplemented_embed_input_ids,
             # Prevent vLLM from trying to load weights into this dummy class.
             "load_weights": lambda self, *args, **kwargs: None,
         })
@@ -26,6 +26,7 @@ from tpu_inference import utils
 from tpu_inference.distributed.jax_parallel_state import get_pp_group
 from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.pp_utils import PPMissingLayer, make_layers
 from tpu_inference.layers.jax.rope_interface import apply_rope
@@ -34,6 +35,7 @@ from tpu_inference.models.jax.jax_intermediate_tensor import \
     JaxIntermediateTensors
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
                                                          load_hf_weights)
+from tpu_inference.utils import get_mesh_shape_product
 
 logger = init_logger(__name__)
 
@@ -98,7 +100,8 @@ class LlamaAttention(nnx.Module):
                                             self.hidden_size // self.num_heads)
         self.head_dim = utils.get_padded_head_dim(self.head_dim_original)
 
-        sharding_size = mesh.shape["model"] * mesh.shape.get("attn_dp", 1)
+        sharding_size = get_mesh_shape_product(mesh,
+                                               ShardingAxisName.MLP_TENSOR)
         self.num_heads = utils.get_padded_num_heads(self.num_heads,
                                                     sharding_size)
         self.num_kv_heads = utils.get_padded_num_heads(self.num_kv_heads,
@@ -171,8 +174,8 @@ class LlamaAttention(nnx.Module):
             # q_scale = self._q_scale
             k_scale = self._k_scale
             v_scale = self._v_scale
-            k, v = utils.quantize_kv(k, v, self.kv_cache_quantized_dtype,
-                                     k_scale, v_scale)
+            k, v = quantize_kv(self.kv_cache_quantized_dtype, k, v, k_scale,
+                               v_scale)
         new_kv_cache, outputs = attention(
             kv_cache,
             q,
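llama3.py and qwen2.py now call quantize_kv from layers.common.quantization with the dtype first, instead of utils.quantize_kv with the dtype third. A rough sketch of what a helper with that signature could look like; the real implementation (clipping, per-head scales, etc.) may differ:

def quantize_kv(kv_quant_dtype, k, v, k_scale, v_scale):
    """Scale K/V by their per-tensor scales and cast to the KV-cache dtype."""
    k_q = (k / k_scale).astype(kv_quant_dtype)
    v_q = (v / v_scale).astype(kv_quant_dtype)
    return k_q, v_q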
@@ -369,13 +372,13 @@ class LlamaForCausalLM(nnx.Module):
         kv_caches: List[jax.Array],
         input_ids: jax.Array,
         attention_metadata: AttentionMetadata,
-        _input_embeds,
-        _input_positions,
-        _layer_name_to_kv_cache,
-        _lora_metadata,
-        intermediate_tensors: JaxIntermediateTensors,
-        _is_first_rank: bool,
-        _is_last_rank: bool,
+        _input_embeds=None,
+        _input_positions=None,
+        _layer_name_to_kv_cache=None,
+        _lora_metadata=None,
+        intermediate_tensors: JaxIntermediateTensors | None = None,
+        _is_first_rank: bool | None = None,
+        _is_last_rank: bool | None = None,
         *args,
     ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]] | Tuple[
             List[jax.Array], JaxIntermediateTensors]:
@@ -256,7 +256,7 @@ class LlamaGuard4ForCausalLM(nnx.Module):
                               self.lm_head.input_embedding_table_DV.value)
         return logits_TV
 
-    def get_input_embeddings(
+    def embed_input_ids(
         self,
         input_ids: jax.Array,
         multimodal_embeddings: Optional[List[jax.Array]] = None
@@ -24,6 +24,7 @@ from vllm.config import VllmConfig
 from tpu_inference import utils
 from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.quantization import quantize_kv
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
@@ -166,8 +167,8 @@ class Qwen2Attention(nnx.Module):
             # q_scale = self._q_scale
             k_scale = self._k_scale
             v_scale = self._v_scale
-            k, v = utils.quantize_kv(k, v, self.kv_cache_quantized_dtype,
-                                     k_scale, v_scale)
+            k, v = quantize_kv(self.kv_cache_quantized_dtype, k, v, k_scale,
+                               v_scale)
         new_kv_cache, outputs = attention(
             kv_cache,
             q,
@@ -1010,9 +1010,9 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
         split_indices = np.cumsum(sizes)[:-1]
         return tuple(jnp.split(image_embeds, split_indices))
 
-    def get_multimodal_embeddings(self, image_grid_thw: tuple[tuple[int, int,
-                                                              int], ...],
-                                  **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, image_grid_thw: tuple[tuple[int, int, int],
+                                                     ...],
+                         **kwargs: object) -> MultiModalEmbeddings:
 
         mm_input_by_modality = self._parse_and_validate_multimodal_inputs(
             image_grid_thw, **kwargs)
@@ -1036,7 +1036,7 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
 
         return multimodal_embeddings
 
-    def get_input_embeddings(
+    def embed_input_ids(
             self, input_ids: jax.Array,
             multimodal_embeddings: Optional[jax.Array]) -> jax.Array:
 
@@ -43,25 +43,25 @@ def sanity_check_mm_encoder_outputs(
 ) -> None:
     """
     Perform sanity checks for the result of
-    [`vllm.model_executor.models.SupportsMultiModal.get_multimodal_embeddings`][].
+    [`vllm.model_executor.models.SupportsMultiModal.embed_multimodal`][].
     """
     assert isinstance(mm_embeddings, (list, tuple, jax.Array)), (
         "Expected multimodal embeddings to be a list/tuple of 2D tensors, "
         f"or a single 3D tensor, but got {type(mm_embeddings)} "
         "instead. This is most likely due to incorrect implementation "
-        "of the model's `get_multimodal_embeddings` method.")
+        "of the model's `embed_multimodal` method.")
 
     assert len(mm_embeddings) == expected_num_items, (
         "Expected number of multimodal embeddings to match number of "
         f"input items: {expected_num_items}, but got {len(mm_embeddings)=} "
         "instead. This is most likely due to incorrect implementation "
-        "of the model's `get_multimodal_embeddings` method.")
+        "of the model's `embed_multimodal` method.")
 
     assert all(e.ndim == 2 for e in mm_embeddings), (
         "Expected multimodal embeddings to be a sequence of 2D tensors, "
         f"but got tensors with shapes {[e.shape for e in mm_embeddings]} "
         "instead. This is most likely due to incorrect implementation "
-        "of the model's `get_multimodal_embeddings` method.")
+        "of the model's `embed_multimodal` method.")
 
 
 def flatten_embeddings(embeddings: NestedTensors) -> jax.Array:
@@ -35,7 +35,7 @@ DEFAULT_NUM_TOKENS_FOR_MODEL_INPUTS = 512
 DEFAULT_MAX_NUM_SEQS_FOR_MODEL_INPUTS = 256
 DEFAULT_MAX_NUM_BLOCKS_PER_REQ = 16
 
-DEFAULT_DEEPSEEK_FP8_CONFIG = {
+DEFAULT_DEEPSEEK_FP4_MLP_MOE_FP8_ATTN_CONFIG = {
     "qwix": {
         "use_abstract_model":
         True,
@@ -452,7 +452,7 @@ def get_default_qwix_quantization_config(
     # NOTE (jacobplatin): we'll default to mixed FP8 (attention) + FP4 (MoE experts)
     # for DeepSeek
     if model_type == "deepseek_v3" and quant_method == "fp8":
-        config = copy.deepcopy(DEFAULT_DEEPSEEK_FP8_CONFIG)
+        config = copy.deepcopy(DEFAULT_DEEPSEEK_FP4_MLP_MOE_FP8_ATTN_CONFIG)
 
         # Dynamically fetch block size from HF config if available
         # Config fmt: 'weight_block_size': [1, 512] -> we want the 2nd dim for tile_size
@@ -462,7 +462,7 @@ def get_default_qwix_quantization_config(
             block_size = hf_quant_config["weight_block_size"]
             if isinstance(block_size, (list, tuple)) and len(block_size) == 2:
                 assert block_size[
-                    0] == 1, f"Expected first dimension to be 1 (unchanneled), but got {block_size[0]}!"
+                    0] == 1, f"Expected first dimension to be 1 (unchanneled), but got {block_size[0]}! If you are trying to run quantized DeepSeek, we currently only support 1D-subchannel quantization and those models can be found here: https://huggingface.co/collections/jrplatin/deepseek-r1-1d-subchannel"
                 tile_size = block_size[1]
                 assert tile_size > 1, f"Expected tile_size > 1 for DeepSeek, but got {tile_size}"
                 logger.info(
@@ -37,6 +37,7 @@ from vllm.model_executor.models import supports_lora, supports_multimodal
 from vllm.sequence import IntermediateTensors
 
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
 from tpu_inference.layers.vllm.sharding import shard_model_to_tpu
 from tpu_inference.logger import init_logger
@@ -234,8 +235,10 @@ class VllmModelWrapper:
 
         @functools.partial(
             jax.jit,
-            out_shardings=(NamedSharding(self.mesh,
-                                         PartitionSpec("data", "model"))),
+            out_shardings=(NamedSharding(
+                self.mesh,
+                PartitionSpec(ShardingAxisName.MLP_DATA,
+                              ShardingAxisName.MLP_TENSOR))),
         )
         def compute_logits_func(
             params_and_buffers: Any,
@@ -168,12 +168,12 @@ class TpuPlatform(Platform):
         multihost_backend = envs.TPU_MULTIHOST_BACKEND
         if not multihost_backend:  # Single host
             if parallel_config.pipeline_parallel_size == 1:
-                logger.info("Force using UniProcExecutor for JAX on \
-                    single host without pipeline parallelism.")
+                logger.info("Force using UniProcExecutor for JAX on "
+                            "single host without pipeline parallelism.")
                 parallel_config.distributed_executor_backend = "uni"
             else:
-                logger.info("Force using MultiprocExecutor for JAX on \
-                    single host with pipeline parallelism.")
+                logger.info("Force using MultiprocExecutor for JAX on "
+                            "single host with pipeline parallelism.")
                 parallel_config.distributed_executor_backend = "mp"
         elif multihost_backend == "ray":
             from tpu_inference.executors.ray_distributed_executor import \
@@ -189,9 +189,9 @@ class TpuPlatform(Platform):
 
         if scheduler_config.is_multimodal_model and not \
             scheduler_config.disable_chunked_mm_input:
-            logger.warning("TPU does not support running Multimodal models"\
-                " without setting `--disable_chunked_mm_input`. " \
-                "Forcing --disable_chunked_mm_input.")
+            logger.warning("TPU does not support running Multimodal models"
+                           " without setting `--disable_chunked_mm_input`. "
+                           "Forcing --disable_chunked_mm_input.")
             scheduler_config.disable_chunked_mm_input = True
 
         kv_transfer_config = vllm_config.kv_transfer_config
@@ -127,7 +127,7 @@ class CompilationManager:
 
         self._run_compilation(
             "input_embeddings_merger",
-            self.runner.get_input_embeddings_fn,
+            self.runner.embed_input_ids_fn,
             self.runner.state,
             dummy_input_ids,
             dummy_multimodal_embeddings,
@@ -136,7 +136,7 @@ class CompilationManager:
 
         self._run_compilation(
             "input_embeddings_merger_text_only",
-            self.runner.get_input_embeddings_fn,
+            self.runner.embed_input_ids_fn,
             self.runner.state,
             dummy_input_ids,
             None,
@@ -495,35 +495,37 @@
             logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
                                                logits_sharding)
             for do_sampling in (True, False):
-                if do_sampling:
-                    temperature = np.full((num_reqs, ), 0.7, dtype=np.float32)
-                    top_k = np.full((num_reqs, ), 20, dtype=np.int32)
-                    top_p = np.full((num_reqs, ), 0.8, dtype=np.float32)
-                    (temperature, top_k,
-                     top_p) = device_array(self.runner.mesh,
-                                           (temperature, top_k, top_p),
-                                           sharding=sampling_metadata_sharding)
-                else:
-                    temperature = None
-                    top_k = None
-                    top_p = None
-
-                sampling_metadata = TPUSupportedSamplingMetadata(
-                    temperature=temperature,
-                    top_k=top_k,
-                    top_p=top_p,
-                    do_sampling=do_sampling,
-                )
-                self._run_compilation(
-                    f"worker{self.runner.rank} sample",
-                    sample,
-                    self.runner.rng_params_for_sampling,
-                    self.runner.mesh,
-                    logits,
-                    sampling_metadata,
-                    num_reqs=num_reqs,
-                    do_sampling=do_sampling,
-                )
+                for logprobs in (True, False):
+                    if do_sampling:
+                        temperature = np.full((num_reqs, ),
+                                              0.7,
+                                              dtype=np.float32)
+                        top_k = np.full((num_reqs, ), 20, dtype=np.int32)
+                        top_p = np.full((num_reqs, ), 0.8, dtype=np.float32)
+                        (temperature, top_k, top_p) = device_array(
+                            self.runner.mesh, (temperature, top_k, top_p),
+                            sharding=sampling_metadata_sharding)
+                    else:
+                        temperature = None
+                        top_k = None
+                        top_p = None
+
+                    sampling_metadata = TPUSupportedSamplingMetadata(
+                        temperature=temperature,
+                        top_k=top_k,
+                        top_p=top_p,
+                        do_sampling=do_sampling,
+                        logprobs=logprobs)
+                    self._run_compilation(
+                        f"worker{self.runner.rank} sample",
+                        sample,
+                        self.runner.rng_params_for_sampling,
+                        self.runner.mesh,
+                        logits,
+                        sampling_metadata,
+                        num_reqs=num_reqs,
+                        do_sampling=do_sampling,
+                    )
 
         self._sampling_precompiled = True
 
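The reworked loop above precompiles the sample function for every (do_sampling, logprobs) combination so no JIT compilation happens on the serving path. A standalone sketch of the same idea using plain jax.jit AOT lowering; the sample body and shapes here are toy stand-ins, and the runner's _run_compilation helper may work differently:

import itertools

import jax
import jax.numpy as jnp


def sample(logits, do_sampling: bool, logprobs: bool):
    # Toy stand-in for the real sample(): greedy tokens, optionally with
    # log-probabilities. do_sampling only selects which variant is compiled.
    out = {"token_ids": jnp.argmax(logits, axis=-1)}
    if logprobs:
        out["logprobs"] = jax.nn.log_softmax(logits.astype(jnp.float32),
                                             axis=-1)
    return out


logits = jnp.zeros((8, 32000), jnp.bfloat16)
compiled = {}
for do_sampling, logprobs in itertools.product((True, False), repeat=2):
    # The booleans are static, so each combination becomes its own executable.
    compiled[(do_sampling, logprobs)] = jax.jit(
        sample, static_argnums=(1, 2)).lower(logits, do_sampling,
                                             logprobs).compile()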
@@ -555,8 +557,16 @@
         logger.info("Compiling gather_logprobs with different input shapes.")
         hsize = self.runner.model_config.get_vocab_size()
         for num_reqs in self.runner.num_reqs_paddings:
-            logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16)
-            token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32)
+            logits_sharding = NamedSharding(
+                self.runner.mesh,
+                PartitionSpec(ShardingAxisName.MLP_DATA,
+                              ShardingAxisName.MLP_TENSOR))
+            token_ids_sharding = NamedSharding(
+                self.runner.mesh, PartitionSpec(ShardingAxisName.MLP_DATA, ))
+            logits = self._create_dummy_tensor((num_reqs, hsize), jnp.bfloat16,
+                                               logits_sharding)
+            token_ids = self._create_dummy_tensor((num_reqs, ), jnp.int32,
+                                                  token_ids_sharding)
             self._run_compilation(
                 f"worker{self.runner.rank} gather_logprobs",
                 self.runner._compute_and_gather_logprobs,