tpu-inference 0.11.1.dev202511130813__py3-none-any.whl → 0.11.1.dev202511180814__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (37)
  1. tests/test_envs.py +182 -0
  2. tests/test_utils.py +23 -14
  3. tpu_inference/core/core_tpu.py +17 -9
  4. tpu_inference/executors/ray_distributed_executor.py +24 -11
  5. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +33 -10
  6. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +7 -0
  7. tpu_inference/layers/{jax → common}/attention_interface.py +1 -1
  8. tpu_inference/layers/common/quant_methods.py +8 -0
  9. tpu_inference/layers/jax/attention/attention.py +1 -1
  10. tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
  11. tpu_inference/layers/jax/sample/sampling.py +2 -2
  12. tpu_inference/layers/vllm/attention.py +1 -1
  13. tpu_inference/layers/vllm/quantization/__init__.py +7 -3
  14. tpu_inference/layers/vllm/quantization/awq.py +4 -3
  15. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -2
  16. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  17. tpu_inference/layers/vllm/quantization/unquantized.py +4 -3
  18. tpu_inference/models/common/model_loader.py +3 -2
  19. tpu_inference/models/jax/llama3.py +2 -2
  20. tpu_inference/models/jax/phi3.py +1 -1
  21. tpu_inference/models/jax/qwen2.py +1 -1
  22. tpu_inference/models/jax/qwen2_5_vl.py +2 -2
  23. tpu_inference/models/jax/qwen3.py +1 -1
  24. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -10
  25. tpu_inference/platforms/tpu_platform.py +12 -5
  26. tpu_inference/runner/compilation_manager.py +4 -2
  27. tpu_inference/runner/kv_cache.py +1 -1
  28. tpu_inference/runner/tpu_runner.py +31 -7
  29. tpu_inference/utils.py +2 -2
  30. tpu_inference/worker/tpu_worker.py +1 -1
  31. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA +1 -1
  32. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/RECORD +37 -34
  33. /tpu_inference/layers/{jax → common}/binary_search.py +0 -0
  34. /tpu_inference/layers/{jax → common}/sharding.py +0 -0
  35. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/WHEEL +0 -0
  36. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/licenses/LICENSE +0 -0
  37. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/mxfp4.py ADDED
@@ -0,0 +1,266 @@
+from typing import Callable, Optional, Union
+
+import jax
+import jax.numpy as jnp
+import torch
+from jax.experimental.layout import Format, Layout
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from torch.nn.parameter import Parameter
+from torchax.interop import jax_view, torch_view
+from torchax.ops.mappings import t2j
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEConfig, FusedMoEQuantConfig, biased_moe_quant_config)
+from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
+                                                        FusedMoEMethodBase)
+from vllm.model_executor.layers.linear import LinearBase
+from vllm.model_executor.layers.quantization import \
+    register_quantization_config
+from vllm.model_executor.layers.quantization.base_config import \
+    QuantizeMethodBase
+from vllm.model_executor.layers.quantization.mxfp4 import (Mxfp4Backend,
+                                                           Mxfp4Config,
+                                                           Mxfp4MoEMethod)
+from vllm.model_executor.layers.quantization.utils.quant_utils import \
+    is_layer_skipped
+
+from tpu_inference.layers.common.quant_methods import (MXFP4,
+                                                       get_tpu_quant_method)
+from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
+from tpu_inference.layers.vllm.linear_common import \
+    reorder_concatenated_tensor_for_sharding
+from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedLinearMethod
+
+MXFP4_BLOCK_SIZE = 32
+
+P = PartitionSpec
+logger = init_logger(__name__)
+
+
+# TODO(kyuyeunk): Move these functions into a common utility file.
+def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:
+    assert u8_packed_e2m1.dtype == jnp.uint8
+    e2m1 = jax.lax.bitcast_convert_type(u8_packed_e2m1, jnp.float4_e2m1fn)
+    # bitcast creates one more dimension that splits 8 bits into two e2m1.
+    # we flatten them with the last dim.
+    return jnp.reshape(e2m1, e2m1.shape[:-2] + (-1, ))
+
+
+def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
+    e8_finfo = jnp.finfo(jnp.float8_e8m0fnu)
+    exponents = u8.astype(jnp.int32) + e8_finfo.minexp
+    ones = jnp.ones_like(u8, dtype=jnp.float32)
+    return jnp.ldexp(ones, exponents)
+
+
+def dequantize_block_weight(weight: jax.Array,
+                            scale: jax.Array,
+                            block_size: int,
+                            out_dtype: jnp.dtype = jnp.bfloat16) -> jax.Array:
+    orig_shape = weight.shape
+    weight_block = weight.reshape(orig_shape[:-1] + (-1, block_size))
+    weight_dequantized = weight_block.astype(jnp.float32) * jnp.expand_dims(
+        scale, -1)
+    return weight_dequantized.reshape(orig_shape).astype(out_dtype)
+
+
+@register_quantization_config(get_tpu_quant_method(MXFP4))
+class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
+
+    @classmethod
+    def get_name(cls):
+        return MXFP4
+
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
+        from vllm.attention.layer import Attention  # Avoid circular import
+
+        if isinstance(layer, LinearBase):
+            linear_config = self.get_linear_config(layer)
+            if self.ignored_layers and is_layer_skipped(
+                    prefix=prefix,
+                    ignored_layers=self.ignored_layers,
+                    fused_mapping=self.packed_modules_mapping,
+            ):
+                return VllmUnquantizedLinearMethod(linear_config)
+            # TODO: Add support for MXFP4 Linear Method.
+            # MXFP4 LinearMethod is available in AMD-Quark, refer to that
+            # implementation if you are interested in enabling MXFP4 here.
+            logger.warning_once(
+                "MXFP4 linear layer is not implemented - falling back to "
+                "UnquantizedLinearMethod.")
+            return VllmUnquantizedLinearMethod(linear_config)
+        elif isinstance(layer, FusedMoE):
+            return VllmMxfp4MoEMethod(layer.moe_config, self.mesh)
+        elif isinstance(layer, Attention):
+            # TODO: Add support for MXFP4 Attention.
+            logger.warning_once("MXFP4 attention layer is not implemented. "
+                                "Skipping quantization for this layer.")
+        return None
+
+
+class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
+
+    def __init__(self, moe: FusedMoEConfig, mesh: Mesh):
+        FusedMoEMethodBase.__init__(self, moe)
+
+        # We piggyback on triton implementation as it applies minimal hardware
+        # specific post processing to the weights.
+        self.mxfp4_backend = Mxfp4Backend.TRITON
+        self.mesh = mesh
+
+    def get_fused_moe_quant_config(
+            self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
+        # Because we have dequantized weights, we only need biased moe config.
+        # TODO(kyuyeunk): Add native support for MXFP4.
+        return biased_moe_quant_config(
+            layer.w13_bias,
+            layer.w2_bias,
+        )
+
+    def process_weights_after_loading(self, layer: torch.nn.Module):
+        assert isinstance(layer, FusedMoE)
+
+        w13_weight = u8_unpack_e2m1(t2j(layer.w13_weight, use_dlpack=False))
+        w13_weight_scale = e8m0_to_fp32(
+            t2j(layer.w13_weight_scale, use_dlpack=False))
+        w13_bias = t2j(layer.w13_bias, use_dlpack=False)
+
+        w2_weight = u8_unpack_e2m1(t2j(layer.w2_weight, use_dlpack=False))
+        w2_weight_scale = e8m0_to_fp32(
+            t2j(layer.w2_weight_scale, use_dlpack=False))
+        w2_bias = t2j(layer.w2_bias, use_dlpack=False)
+
+        # We dequantize fp4 weights into bf16.
+        # TODO(kyuyeunk): Add native support for MXFP4.
+        w13_weight = dequantize_block_weight(w13_weight, w13_weight_scale,
+                                             MXFP4_BLOCK_SIZE, jnp.bfloat16)
+        w2_weight = dequantize_block_weight(w2_weight, w2_weight_scale,
+                                            MXFP4_BLOCK_SIZE, jnp.bfloat16)
+
+        # Because we have dequantized weights, scales are not used anymore.
+        delattr(layer, "w13_weight_scale")
+        delattr(layer, "w2_weight_scale")
+
+        if layer.activation == "swigluoai":
+            # When using swigluoai, vLLM splits gmm output in a interleaved way.
+            # However, interleaved split is not performant on TPU. Therefore,
+            # we preprocess the weight so that splitting gmm output by middle
+            # can still get the same result.
+            w1_weight = w13_weight[:, ::2, :]
+            w3_weight = w13_weight[:, 1::2, :]
+            w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
+
+            w1_bias = w13_bias[:, ::2]
+            w3_bias = w13_bias[:, 1::2]
+            w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
+
+        # TODO(kyuyeunk): Add weight processing logic for the new kernel.
+        if layer.use_ep:
+            w13_weight = jax.device_put(
+                w13_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P("model", None, None))))
+            w2_weight = jax.device_put(
+                w2_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P("model", None, None))))
+
+            w13_bias = jax.device_put(
+                w13_bias,
+                Format(Layout((0, 1)),
+                       NamedSharding(self.mesh, P("model", None))))
+            w2_bias = jax.device_put(
+                w2_bias,
+                Format(Layout((0, 1)),
+                       NamedSharding(self.mesh, P("model", None))))
+
+        else:
+            intermediate_size = w13_weight.shape[1] // 2
+            assert intermediate_size == w2_weight.shape[-1]
+            output_sizes = [intermediate_size, intermediate_size]
+            n_shards = self.mesh.shape["model"]
+            assert intermediate_size % n_shards == 0
+            w13_weight = reorder_concatenated_tensor_for_sharding(w13_weight,
+                                                                  output_sizes,
+                                                                  n_shards,
+                                                                  dim=1)
+            w13_weight = jax.device_put(
+                w13_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P(None, "model", None))))
+            w2_weight = jax.device_put(
+                w2_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P(None, None, "model"))))
+
+            w13_bias = reorder_concatenated_tensor_for_sharding(w13_bias,
+                                                                output_sizes,
+                                                                n_shards,
+                                                                dim=1)
+            w13_bias = jax.device_put(
+                w13_bias,
+                Format(Layout((0, 1)),
+                       NamedSharding(self.mesh, P(None, "model"))))
+            w2_bias = jax.device_put(
+                w2_bias,
+                Format(Layout((0, 1)), NamedSharding(self.mesh, P(None,
+                                                                  None))))
+
+        layer.w13_weight = Parameter(torch_view(w13_weight),
+                                     requires_grad=False)
+        layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
+
+        layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
+        layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
+
+        pass
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        assert isinstance(layer, FusedMoE)
+        if scoring_func != "softmax":
+            raise NotImplementedError(
+                "Only softmax is supported for scoring_func")
+
+        # Use the original implementation
+        output = fused_moe_func_padded(
+            jax_view(x),
+            jax_view(layer.w13_weight),
+            jax_view(layer.w2_weight),
+            jax_view(layer.w13_bias) if self.moe.has_bias else None,
+            jax_view(layer.w2_bias) if self.moe.has_bias else None,
+            jax_view(router_logits),
+            topk=top_k,
+            global_num_experts=global_num_experts,
+            renormalize=renormalize,
+            reduce_results=layer.reduce_results,
+            mesh=self.mesh,
+            use_ep=layer.use_ep,
+            activation=activation,
+        )
+
+        return torch_view(output)
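
Taken together, the helpers above form the dequantization path used by VllmMxfp4MoEMethod: packed uint8 bytes are unpacked into e2m1 values, e8m0 exponent bytes become float32 scales, and each 32-value block is rescaled into bfloat16. As a rough standalone sketch (not part of the diff; the shapes and values below are toy examples chosen for illustration), the pieces compose like this:

import jax.numpy as jnp

from tpu_inference.layers.vllm.quantization.mxfp4 import (
    MXFP4_BLOCK_SIZE, dequantize_block_weight, e8m0_to_fp32, u8_unpack_e2m1)

# Toy inputs: 1 expert, 2 rows, 64 packed bytes per row -> 128 e2m1 values,
# i.e. 4 blocks of 32 values, each block with its own e8m0 scale byte.
packed = jnp.zeros((1, 2, 64), dtype=jnp.uint8)
scales_e8m0 = jnp.full((1, 2, 4), 127, dtype=jnp.uint8)  # 127 -> 2**0 == 1.0

weights_e2m1 = u8_unpack_e2m1(packed)            # (1, 2, 128), float4_e2m1fn
scales_f32 = e8m0_to_fp32(scales_e8m0)           # (1, 2, 4), float32
weights_bf16 = dequantize_block_weight(
    weights_e2m1, scales_f32, MXFP4_BLOCK_SIZE)  # (1, 2, 128), bfloat16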
tpu_inference/layers/vllm/quantization/unquantized.py CHANGED
@@ -23,6 +23,8 @@ from vllm.model_executor.layers.quantization.base_config import (
 
 from tpu_inference import envs
 from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
+from tpu_inference.layers.common.quant_methods import (UNQUANTIZED,
+                                                        get_tpu_quant_method)
 from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
 from tpu_inference.layers.vllm.linear_common import (
     reorder_concatenated_tensor_for_sharding,
@@ -34,12 +36,12 @@ P = PartitionSpec
 logger = init_logger(__name__)
 
 
-@register_quantization_config("jax-unquantized")
+@register_quantization_config(get_tpu_quant_method(UNQUANTIZED))
 class VllmUnquantizedConfig(QuantizationConfig, JaxCommonConfig):
 
     @classmethod
     def get_name(cls) -> str:
-        return "jax-unquantized"
+        return UNQUANTIZED
 
     @classmethod
     def get_supported_act_dtypes(cls) -> list[torch.dtype]:
@@ -189,7 +191,6 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         assert isinstance(layer, FusedMoE)
-
         w13_weight = t2j(layer.w13_weight, use_dlpack=False)
         w2_weight = t2j(layer.w2_weight, use_dlpack=False)
 
tpu_inference/models/common/model_loader.py CHANGED
@@ -11,7 +11,7 @@ from vllm.config import VllmConfig
 from vllm.utils.func_utils import supports_kw
 
 from tpu_inference import envs
-from tpu_inference.layers.jax.sharding import ShardingAxisName
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.quantization.quantization_utils import (
     apply_qwix_on_abstract_model, apply_qwix_quantization,
@@ -242,10 +242,11 @@ def get_flax_model(
         model = nnx.merge(graphdef, state)
         return model.get_multimodal_embeddings(image_grid_thw, **kwargs)
 
+    embed_sharding = NamedSharding(mesh, PartitionSpec(None))
     # This function will calculates the embeddings of input texts and then merge with the image embeddings
     @functools.partial(
         jax.jit,
-        out_shardings=(logits_sharding),
+        out_shardings=(embed_sharding),
     )
     def run_get_input_embeddings(graphdef, state, *args, **kwargs):
         model = nnx.merge(graphdef, state)
tpu_inference/models/jax/llama3.py CHANGED
@@ -8,10 +8,10 @@ from transformers import LlamaConfig, modeling_flax_utils
 from vllm.config import VllmConfig
 
 from tpu_inference import utils
+from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.attention_interface import attention
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.rope_interface import apply_rope
-from tpu_inference.layers.jax.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
                                                          load_hf_weights)
tpu_inference/models/jax/phi3.py CHANGED
@@ -8,8 +8,8 @@ from transformers import Phi3Config, modeling_flax_utils
 from vllm.config import VllmConfig
 
 from tpu_inference import utils
+from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.attention_interface import attention
 from tpu_inference.layers.jax.rope_interface import apply_longrope, apply_rope
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.weight_utils import (MetadataMap,
tpu_inference/models/jax/qwen2.py CHANGED
@@ -8,8 +8,8 @@ from transformers import Qwen2Config, modeling_flax_utils
 from vllm.config import VllmConfig
 
 from tpu_inference import utils
+from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.attention_interface import attention
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
tpu_inference/models/jax/qwen2_5_vl.py CHANGED
@@ -14,9 +14,9 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
 from vllm.config import VllmConfig
 
 from tpu_inference import utils as utils
-from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.attention_interface import \
+from tpu_inference.layers.common.attention_interface import \
     sharded_flash_attention
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
 # from vllm.model_executor.models.interfaces import MultiModalEmbeddings
tpu_inference/models/jax/qwen3.py CHANGED
@@ -8,8 +8,8 @@ from transformers import Qwen3Config
 from vllm.config import VllmConfig
 
 from tpu_inference import utils
+from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.attention_interface import attention
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.qwen2 import Qwen2DecoderLayer
tpu_inference/models/vllm/vllm_model_wrapper.py CHANGED
@@ -25,6 +25,8 @@ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
 from tpu_inference.layers.vllm.sharding import shard_model_to_tpu
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.models.vllm.vllm_model_wrapper_context import (
     get_vllm_model_wrapper_context, set_vllm_model_wrapper_context)
 from tpu_inference.runner.lora_utils import replace_lora_metadata
@@ -89,13 +91,14 @@ class VllmModelWrapper:
             slice_config = self.vllm_config.device_config.slice
             modified_slice_config = True
             self.vllm_config.device_config.slice = None
+        self.vllm_config.compilation_config.static_forward_context.clear()
+
         vllm_config_for_load = copy.deepcopy(self.vllm_config)
         if modified_slice_config:
             self.vllm_config.device_config.slice = slice_config
         assert self.vllm_config.model_config.dtype in TORCH_DTYPE_TO_JAX, "The model_config.dtype must be a PyTorch dtype."
         vllm_config_for_load.device_config.device = "cpu"
         # Clearing the cached compilation config, otherwise vllm model init will fail
-        vllm_config_for_load.compilation_config.static_forward_context.clear()
 
 
         # When expert parallelism is enabled, vLLM loads weight in sharding
@@ -117,7 +120,8 @@ class VllmModelWrapper:
 
         # Load the vLLM model and wrap it into a new model whose forward
         # function can calculate the hidden_state and logits.
-        with load_context:
+        available_devices = self.mesh.devices.flatten()
+        with load_context, jax.default_device(available_devices[0]):
             vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
             lora_manager = None
             if vllm_config_for_load.lora_config is not None:
@@ -149,7 +153,8 @@ class VllmModelWrapper:
                 "xla_tpu_reduce_scatter_collective_matmul_mode":
                 "post_spmd_conservative"
             },
-            static_argnames=("layer_name_to_kvcache_index", ),
+            static_argnames=("layer_name_to_kvcache_index", "is_first_rank",
+                             "is_last_rank"),
         )
         def step_fun(
             params_and_buffers,  # This has been wrapped into torchax TorchValue
@@ -159,6 +164,9 @@ class VllmModelWrapper:
             input_embeds: jax.Array,
             layer_name_to_kvcache_index: Sequence[Tuple[str, int]],
             lora_metadata,
+            intermediate_tensors: JaxIntermediateTensors = None,
+            is_first_rank: bool = True,
+            is_last_rank: bool = True,
             *args,
         ) -> Tuple[List[jax.Array], jax.Array]:
             layer_name_to_kvcache_index = dict(layer_name_to_kvcache_index)
@@ -173,13 +181,15 @@ class VllmModelWrapper:
             # torch_view in order to call the Torch function.
             original_lora_metadata = replace_lora_metadata(
                 self.model, lora_metadata, self.vllm_config.lora_config)
-            hidden_states = torch.func.functional_call(
+            if not is_first_rank:
+                intermediate_tensors = intermediate_tensors.to_torch()
+            output_from_torch = torch.func.functional_call(
                 self.model,
                 torch_view(params_and_buffers),
                 kwargs={
                     "input_ids": torch_view(input_ids),
                     "positions": torch_view(attn_metadata.input_positions),
-                    "intermediate_tensors": None,
+                    "intermediate_tensors": intermediate_tensors,
                     "inputs_embeds": None,
                 },
                 tie_weights=False,
@@ -188,11 +198,13 @@ class VllmModelWrapper:
                 self.vllm_config.lora_config)
             vllm_model_wrapper_context = get_vllm_model_wrapper_context()
             new_kv_caches = vllm_model_wrapper_context.kv_caches
-            # Wrap the hidden_states from torch land into a JaxValue for the jax
-            # code to consume.
-            hidden_states = jax_view(hidden_states)
-
-            return new_kv_caches, hidden_states, []
+            # Wrap the output(hidden states or intermediate tensor)
+            # from torch land into a JaxValue for the jax code to consume.
+            if not is_last_rank:
+                output = JaxIntermediateTensors.from_torch(output_from_torch)
+            else:
+                output = jax_view(output_from_torch)
+            return new_kv_caches, output, []
 
         return step_fun
 
tpu_inference/platforms/tpu_platform.py CHANGED
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import os
-from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
 
 import jax.numpy as jnp
 import vllm.envs as vllm_envs
@@ -12,7 +12,7 @@ from vllm.platforms.interface import Platform, PlatformEnum
 from vllm.sampling_params import SamplingParams, SamplingType
 
 from tpu_inference import envs
-from tpu_inference.layers.jax.sharding import ShardingConfigManager
+from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
 
 if TYPE_CHECKING:
@@ -57,7 +57,8 @@ class TpuPlatform(Platform):
     def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
                              dtype: jnp.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool, use_mla: bool,
-                             has_sink: bool, use_sparse: bool) -> str:
+                             has_sink: bool, use_sparse: bool,
+                             attn_type: Any) -> str:
         from vllm.attention.backends.registry import _Backend
         if selected_backend != _Backend.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)
@@ -184,8 +185,14 @@ class TpuPlatform(Platform):
 
         multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
         if not multihost_backend:  # Single host
-            logger.info("Force using UniProcExecutor for JAX on single host.")
-            parallel_config.distributed_executor_backend = "uni"
+            if parallel_config.pipeline_parallel_size == 1:
+                logger.info("Force using UniProcExecutor for JAX on \
+                    single host without pipeline parallelism.")
+                parallel_config.distributed_executor_backend = "uni"
+            else:
+                logger.info("Force using MultiprocExecutor for JAX on \
+                    single host with pipeline parallelism.")
+                parallel_config.distributed_executor_backend = "mp"
         elif multihost_backend == "ray":
             from tpu_inference.executors.ray_distributed_executor import \
                 RayDistributedExecutor
tpu_inference/runner/compilation_manager.py CHANGED
@@ -10,10 +10,10 @@ from jax.sharding import NamedSharding, PartitionSpec
 
 from tpu_inference.core.disagg_utils import is_disagg_enabled
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.sample.sampling import sample
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
-from tpu_inference.layers.jax.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 from tpu_inference.utils import device_array
 
@@ -332,13 +332,15 @@ class CompilationManager:
         index_paddings = self.runner.num_reqs_paddings
         dp_sharding = NamedSharding(self.runner.mesh,
                                     PartitionSpec(ShardingAxisName.ATTN_DATA))
+        hidden_states_sharding = NamedSharding(
+            self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None))
         dp_size = self.runner.vllm_config.sharding_config.total_dp_size
         self._precompile_select_from_array_helper(
             name="select all logits",
             source_paddings=self.runner.num_tokens_paddings,
             indices_paddings=index_paddings,
             hidden_dim=hsize,
-            input_sharding=dp_sharding,
+            input_sharding=hidden_states_sharding,
             indices_sharding=dp_sharding if dp_size > 1 else None,
         )
 
tpu_inference/runner/kv_cache.py CHANGED
@@ -9,7 +9,7 @@ from torchax.ops.mappings import t2j_dtype
 
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
-from tpu_inference.layers.jax.sharding import ShardingAxisName
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
tpu_inference/runner/tpu_runner.py CHANGED
@@ -27,7 +27,7 @@ from vllm.v1.core.sched.output import GrammarOutput
 from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
-                             DraftTokenIds, KVConnectorOutput,
+                             DraftTokenIds, KVConnectorOutput, LogprobsLists,
                              ModelRunnerOutput)
 from vllm.v1.request import Request
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
@@ -37,15 +37,15 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
 from tpu_inference import utils as common_utils
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
+                                                  MESH_AXIS_NAMES_2D,
+                                                  ShardingAxisName,
+                                                  ShardingConfigManager)
 from tpu_inference.layers.jax.sample.rejection_sampler import RejectionSampler
 from tpu_inference.layers.jax.sample.sampling import (compute_logprobs,
                                                       gather_logprobs, sample)
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
-from tpu_inference.layers.jax.sharding import (MESH_AXIS_NAMES,
-                                               MESH_AXIS_NAMES_2D,
-                                               ShardingAxisName,
-                                               ShardingConfigManager)
 from tpu_inference.logger import init_logger
 from tpu_inference.models.common.model_loader import get_model
 from tpu_inference.models.jax.utils.weight_utils import (
@@ -190,6 +190,21 @@ def _substitute_placeholder_token(
     return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)
 
 
+def _reorder_logits_indices(logprobs_lists, logits_indices_selector):
+    return LogprobsLists(
+        logprob_token_ids=[
+            logprobs_lists.logprob_token_ids[i]
+            for i in logits_indices_selector
+        ],
+        logprobs=[logprobs_lists.logprobs[i] for i in logits_indices_selector],
+        sampled_token_ranks=[
+            logprobs_lists.sampled_token_ranks[i]
+            for i in logits_indices_selector
+        ],
+        cu_num_generated_tokens=logprobs_lists.cu_num_generated_tokens,
+    )
+
+
 class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
     def __init__(
@@ -840,7 +855,12 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 logits_indices_selector)
 
         if logprobs is not None:
+            # Map logprobs back to the pre-dp shuffling order
            logprobs_lists = logprobs.tolists()
+            if logits_indices_selector is not None:
+                logprobs_lists = _reorder_logits_indices(
+                    logprobs_lists, logits_indices_selector)
+
         else:
            logprobs_lists = None
 
@@ -908,7 +928,11 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
            req_state.output_token_ids.extend(sampled_ids)
 
        if logprobs is not None:
+            # Map logprobs back to the pre-dp shuffling order
            logprobs_lists = logprobs.tolists()
+            if logits_indices_selector is not None:
+                logprobs_lists = _reorder_logits_indices(
+                    logprobs_lists, logits_indices_selector)
        else:
            logprobs_lists = None
 
@@ -1315,10 +1339,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
            seq_lens_cpu = seq_lens
 
        (input_ids, positions, block_tables, query_start_loc, seq_lens,
-        logits_indices, request_distribution, logits_indices) = device_array(
+        logits_indices, request_distribution) = device_array(
            self.mesh,
            (input_ids, positions, block_tables, query_start_loc, seq_lens,
-            logits_indices, request_distribution, logits_indices),
+            logits_indices, request_distribution),
            sharding=data_parallel_attn_sharding,
        )
        # Async scheduling: substitute placeholder tokens for DP
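
The _reorder_logits_indices helper added above only permutes the per-request lists back to their pre-shuffling order; the selector gives, for each request, the index of its row in the data-parallel-shuffled output. A toy sketch (values invented for illustration; the field names mirror the construction in the hunk above, and _reorder_logits_indices is the module-level helper defined there):

from vllm.v1.outputs import LogprobsLists

# Three requests whose logprob rows were produced in DP-shuffled order.
lists = LogprobsLists(
    logprob_token_ids=[[11], [22], [33]],
    logprobs=[[-0.1], [-0.2], [-0.3]],
    sampled_token_ranks=[0, 1, 2],
    cu_num_generated_tokens=[0, 1, 2, 3],
)

# Selector [2, 0, 1]: request 0's row sits at shuffled position 2, and so on.
reordered = _reorder_logits_indices(lists, logits_indices_selector=[2, 0, 1])
assert reordered.logprob_token_ids == [[33], [11], [22]]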
tpu_inference/utils.py CHANGED
@@ -132,8 +132,8 @@ def pathways_hbm_usage_gb(devices: Any) -> List[Tuple[float, float]]:
     hbm_used = defaultdict(int)
     hbm_limit = get_device_hbm_limit()
     for array in live_arrays:
-        for buffer in array.device_buffers:
-            hbm_used[buffer.device] += buffer.nbytes
+        for buffer in array.addressable_shards:
+            hbm_used[buffer.data.device] += buffer.data.nbytes
     return [(hbm_used[device], hbm_limit) for device in devices]
 
 
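The pathways_hbm_usage_gb change swaps the removed device_buffers attribute for the newer JAX addressable_shards API: each shard exposes its per-device data array, so byte counts are still accumulated per device. A standalone sketch of the same accounting pattern (toy array on whatever devices are visible; not the package's code):

from collections import defaultdict

import jax.numpy as jnp

x = jnp.ones((1024, 1024), dtype=jnp.bfloat16)  # ~2 MiB of live array data

hbm_used = defaultdict(int)
for shard in x.addressable_shards:
    hbm_used[shard.data.device] += shard.data.nbytes

print(dict(hbm_used))  # bytes held per device by this array
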
tpu_inference/worker/tpu_worker.py CHANGED
@@ -25,7 +25,7 @@ from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 from tpu_inference import envs, utils
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
                                              get_node_id)
-from tpu_inference.layers.jax.sharding import ShardingConfigManager
+from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
 from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
 from tpu_inference.runner.tpu_runner import TPUModelRunner