tpu-inference 0.11.1.dev202511130813__py3-none-any.whl → 0.11.1.dev202511180814__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/test_envs.py +182 -0
- tests/test_utils.py +23 -14
- tpu_inference/core/core_tpu.py +17 -9
- tpu_inference/executors/ray_distributed_executor.py +24 -11
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +33 -10
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +7 -0
- tpu_inference/layers/{jax → common}/attention_interface.py +1 -1
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/jax/attention/attention.py +1 -1
- tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
- tpu_inference/layers/jax/sample/sampling.py +2 -2
- tpu_inference/layers/vllm/attention.py +1 -1
- tpu_inference/layers/vllm/quantization/__init__.py +7 -3
- tpu_inference/layers/vllm/quantization/awq.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -2
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +4 -3
- tpu_inference/models/common/model_loader.py +3 -2
- tpu_inference/models/jax/llama3.py +2 -2
- tpu_inference/models/jax/phi3.py +1 -1
- tpu_inference/models/jax/qwen2.py +1 -1
- tpu_inference/models/jax/qwen2_5_vl.py +2 -2
- tpu_inference/models/jax/qwen3.py +1 -1
- tpu_inference/models/vllm/vllm_model_wrapper.py +22 -10
- tpu_inference/platforms/tpu_platform.py +12 -5
- tpu_inference/runner/compilation_manager.py +4 -2
- tpu_inference/runner/kv_cache.py +1 -1
- tpu_inference/runner/tpu_runner.py +31 -7
- tpu_inference/utils.py +2 -2
- tpu_inference/worker/tpu_worker.py +1 -1
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/METADATA +1 -1
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/RECORD +37 -34
- /tpu_inference/layers/{jax → common}/binary_search.py +0 -0
- /tpu_inference/layers/{jax → common}/sharding.py +0 -0
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511180814.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/mxfp4.py
ADDED
@@ -0,0 +1,266 @@
+from typing import Callable, Optional, Union
+
+import jax
+import jax.numpy as jnp
+import torch
+from jax.experimental.layout import Format, Layout
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from torch.nn.parameter import Parameter
+from torchax.interop import jax_view, torch_view
+from torchax.ops.mappings import t2j
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEConfig, FusedMoEQuantConfig, biased_moe_quant_config)
+from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
+                                                        FusedMoEMethodBase)
+from vllm.model_executor.layers.linear import LinearBase
+from vllm.model_executor.layers.quantization import \
+    register_quantization_config
+from vllm.model_executor.layers.quantization.base_config import \
+    QuantizeMethodBase
+from vllm.model_executor.layers.quantization.mxfp4 import (Mxfp4Backend,
+                                                           Mxfp4Config,
+                                                           Mxfp4MoEMethod)
+from vllm.model_executor.layers.quantization.utils.quant_utils import \
+    is_layer_skipped
+
+from tpu_inference.layers.common.quant_methods import (MXFP4,
+                                                       get_tpu_quant_method)
+from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
+from tpu_inference.layers.vllm.linear_common import \
+    reorder_concatenated_tensor_for_sharding
+from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedLinearMethod
+
+MXFP4_BLOCK_SIZE = 32
+
+P = PartitionSpec
+logger = init_logger(__name__)
+
+
+# TODO(kyuyeunk): Move these functions into a common utility file.
+def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:
+    assert u8_packed_e2m1.dtype == jnp.uint8
+    e2m1 = jax.lax.bitcast_convert_type(u8_packed_e2m1, jnp.float4_e2m1fn)
+    # bitcast creates one more dimension that splits 8 bits into two e2m1.
+    # we flatten them with the last dim.
+    return jnp.reshape(e2m1, e2m1.shape[:-2] + (-1, ))
+
+
+def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
+    e8_finfo = jnp.finfo(jnp.float8_e8m0fnu)
+    exponents = u8.astype(jnp.int32) + e8_finfo.minexp
+    ones = jnp.ones_like(u8, dtype=jnp.float32)
+    return jnp.ldexp(ones, exponents)
+
+
+def dequantize_block_weight(weight: jax.Array,
+                            scale: jax.Array,
+                            block_size: int,
+                            out_dtype: jnp.dtype = jnp.bfloat16) -> jax.Array:
+    orig_shape = weight.shape
+    weight_block = weight.reshape(orig_shape[:-1] + (-1, block_size))
+    weight_dequantized = weight_block.astype(jnp.float32) * jnp.expand_dims(
+        scale, -1)
+    return weight_dequantized.reshape(orig_shape).astype(out_dtype)
+
+
+@register_quantization_config(get_tpu_quant_method(MXFP4))
+class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
+
+    @classmethod
+    def get_name(cls):
+        return MXFP4
+
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
+        from vllm.attention.layer import Attention  # Avoid circular import
+
+        if isinstance(layer, LinearBase):
+            linear_config = self.get_linear_config(layer)
+            if self.ignored_layers and is_layer_skipped(
+                    prefix=prefix,
+                    ignored_layers=self.ignored_layers,
+                    fused_mapping=self.packed_modules_mapping,
+            ):
+                return VllmUnquantizedLinearMethod(linear_config)
+            # TODO: Add support for MXFP4 Linear Method.
+            # MXFP4 LinearMethod is available in AMD-Quark, refer to that
+            # implementation if you are interested in enabling MXFP4 here.
+            logger.warning_once(
+                "MXFP4 linear layer is not implemented - falling back to "
+                "UnquantizedLinearMethod.")
+            return VllmUnquantizedLinearMethod(linear_config)
+        elif isinstance(layer, FusedMoE):
+            return VllmMxfp4MoEMethod(layer.moe_config, self.mesh)
+        elif isinstance(layer, Attention):
+            # TODO: Add support for MXFP4 Attention.
+            logger.warning_once("MXFP4 attention layer is not implemented. "
+                                "Skipping quantization for this layer.")
+        return None
+
+
+class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
+
+    def __init__(self, moe: FusedMoEConfig, mesh: Mesh):
+        FusedMoEMethodBase.__init__(self, moe)
+
+        # We piggyback on triton implementation as it applies minimal hardware
+        # specific post processing to the weights.
+        self.mxfp4_backend = Mxfp4Backend.TRITON
+        self.mesh = mesh
+
+    def get_fused_moe_quant_config(
+            self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
+        # Because we have dequantized weights, we only need biased moe config.
+        # TODO(kyuyeunk): Add native support for MXFP4.
+        return biased_moe_quant_config(
+            layer.w13_bias,
+            layer.w2_bias,
+        )
+
+    def process_weights_after_loading(self, layer: torch.nn.Module):
+        assert isinstance(layer, FusedMoE)
+
+        w13_weight = u8_unpack_e2m1(t2j(layer.w13_weight, use_dlpack=False))
+        w13_weight_scale = e8m0_to_fp32(
+            t2j(layer.w13_weight_scale, use_dlpack=False))
+        w13_bias = t2j(layer.w13_bias, use_dlpack=False)
+
+        w2_weight = u8_unpack_e2m1(t2j(layer.w2_weight, use_dlpack=False))
+        w2_weight_scale = e8m0_to_fp32(
+            t2j(layer.w2_weight_scale, use_dlpack=False))
+        w2_bias = t2j(layer.w2_bias, use_dlpack=False)
+
+        # We dequantize fp4 weights into bf16.
+        # TODO(kyuyeunk): Add native support for MXFP4.
+        w13_weight = dequantize_block_weight(w13_weight, w13_weight_scale,
+                                             MXFP4_BLOCK_SIZE, jnp.bfloat16)
+        w2_weight = dequantize_block_weight(w2_weight, w2_weight_scale,
+                                            MXFP4_BLOCK_SIZE, jnp.bfloat16)
+
+        # Because we have dequantized weights, scales are not used anymore.
+        delattr(layer, "w13_weight_scale")
+        delattr(layer, "w2_weight_scale")
+
+        if layer.activation == "swigluoai":
+            # When using swigluoai, vLLM splits gmm output in a interleaved way.
+            # However, interleaved split is not performant on TPU. Therefore,
+            # we preprocess the weight so that splitting gmm output by middle
+            # can still get the same result.
+            w1_weight = w13_weight[:, ::2, :]
+            w3_weight = w13_weight[:, 1::2, :]
+            w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
+
+            w1_bias = w13_bias[:, ::2]
+            w3_bias = w13_bias[:, 1::2]
+            w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
+
+        # TODO(kyuyeunk): Add weight processing logic for the new kernel.
+        if layer.use_ep:
+            w13_weight = jax.device_put(
+                w13_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P("model", None, None))))
+            w2_weight = jax.device_put(
+                w2_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P("model", None, None))))
+
+            w13_bias = jax.device_put(
+                w13_bias,
+                Format(Layout((0, 1)),
+                       NamedSharding(self.mesh, P("model", None))))
+            w2_bias = jax.device_put(
+                w2_bias,
+                Format(Layout((0, 1)),
+                       NamedSharding(self.mesh, P("model", None))))
+
+        else:
+            intermediate_size = w13_weight.shape[1] // 2
+            assert intermediate_size == w2_weight.shape[-1]
+            output_sizes = [intermediate_size, intermediate_size]
+            n_shards = self.mesh.shape["model"]
+            assert intermediate_size % n_shards == 0
+            w13_weight = reorder_concatenated_tensor_for_sharding(w13_weight,
+                                                                  output_sizes,
+                                                                  n_shards,
+                                                                  dim=1)
+            w13_weight = jax.device_put(
+                w13_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P(None, "model", None))))
+            w2_weight = jax.device_put(
+                w2_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P(None, None, "model"))))
+
+            w13_bias = reorder_concatenated_tensor_for_sharding(w13_bias,
+                                                                output_sizes,
+                                                                n_shards,
+                                                                dim=1)
+            w13_bias = jax.device_put(
+                w13_bias,
+                Format(Layout((0, 1)),
+                       NamedSharding(self.mesh, P(None, "model"))))
+            w2_bias = jax.device_put(
+                w2_bias,
+                Format(Layout((0, 1)), NamedSharding(self.mesh, P(None,
+                                                                  None))))
+
+        layer.w13_weight = Parameter(torch_view(w13_weight),
+                                     requires_grad=False)
+        layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
+
+        layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
+        layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
+
+        pass
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        assert isinstance(layer, FusedMoE)
+        if scoring_func != "softmax":
+            raise NotImplementedError(
+                "Only softmax is supported for scoring_func")
+
+        # Use the original implementation
+        output = fused_moe_func_padded(
+            jax_view(x),
+            jax_view(layer.w13_weight),
+            jax_view(layer.w2_weight),
+            jax_view(layer.w13_bias) if self.moe.has_bias else None,
+            jax_view(layer.w2_bias) if self.moe.has_bias else None,
+            jax_view(router_logits),
+            topk=top_k,
+            global_num_experts=global_num_experts,
+            renormalize=renormalize,
+            reduce_results=layer.reduce_results,
+            mesh=self.mesh,
+            use_ep=layer.use_ep,
+            activation=activation,
+        )
+
+        return torch_view(output)
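Note on the helpers above: MXFP4 stores weights as packed 4-bit e2m1 values in groups of 32, with one shared e8m0 scale per group that decodes to a power of two (2 ** (stored_byte - 127)); dequantization multiplies each block by its decoded scale. A minimal, self-contained sketch of that block arithmetic on toy float32 data (the tensors and values below are illustrative only, not part of the package):

import jax.numpy as jnp

MXFP4_BLOCK_SIZE = 32

def dequantize_block_sketch(weight, scale, block_size, out_dtype=jnp.bfloat16):
    # Split the last dim into (num_blocks, block_size) and scale each block,
    # mirroring dequantize_block_weight() in the file above.
    orig_shape = weight.shape
    blocks = weight.reshape(orig_shape[:-1] + (-1, block_size))
    dequantized = blocks.astype(jnp.float32) * jnp.expand_dims(scale, -1)
    return dequantized.reshape(orig_shape).astype(out_dtype)

# Toy data: 2 rows of 64 values -> 2 blocks of 32 per row.
weight = jnp.ones((2, 64), dtype=jnp.float32)  # stands in for unpacked e2m1 values
stored_e8m0 = jnp.array([[127, 128], [129, 130]], dtype=jnp.int32)
scale = jnp.ldexp(jnp.ones_like(stored_e8m0, dtype=jnp.float32),
                  stored_e8m0 - 127)  # decoded scales: 1, 2, 4, 8
print(dequantize_block_sketch(weight, scale, MXFP4_BLOCK_SIZE)[:, ::32])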
tpu_inference/layers/vllm/quantization/unquantized.py
CHANGED
@@ -23,6 +23,8 @@ from vllm.model_executor.layers.quantization.base_config import (
 
 from tpu_inference import envs
 from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
+from tpu_inference.layers.common.quant_methods import (UNQUANTIZED,
+                                                        get_tpu_quant_method)
 from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
 from tpu_inference.layers.vllm.linear_common import (
     reorder_concatenated_tensor_for_sharding,
@@ -34,12 +36,12 @@ P = PartitionSpec
 logger = init_logger(__name__)
 
 
-@register_quantization_config(
+@register_quantization_config(get_tpu_quant_method(UNQUANTIZED))
 class VllmUnquantizedConfig(QuantizationConfig, JaxCommonConfig):
 
     @classmethod
     def get_name(cls) -> str:
-        return
+        return UNQUANTIZED
 
     @classmethod
     def get_supported_act_dtypes(cls) -> list[torch.dtype]:
@@ -189,7 +191,6 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         assert isinstance(layer, FusedMoE)
-
         w13_weight = t2j(layer.w13_weight, use_dlpack=False)
         w2_weight = t2j(layer.w2_weight, use_dlpack=False)
 
tpu_inference/models/common/model_loader.py
CHANGED
@@ -11,7 +11,7 @@ from vllm.config import VllmConfig
 from vllm.utils.func_utils import supports_kw
 
 from tpu_inference import envs
-from tpu_inference.layers.
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.quantization.quantization_utils import (
     apply_qwix_on_abstract_model, apply_qwix_quantization,
@@ -242,10 +242,11 @@ def get_flax_model(
         model = nnx.merge(graphdef, state)
         return model.get_multimodal_embeddings(image_grid_thw, **kwargs)
 
+    embed_sharding = NamedSharding(mesh, PartitionSpec(None))
     # This function will calculates the embeddings of input texts and then merge with the image embeddings
     @functools.partial(
         jax.jit,
-        out_shardings=(
+        out_shardings=(embed_sharding),
     )
     def run_get_input_embeddings(graphdef, state, *args, **kwargs):
         model = nnx.merge(graphdef, state)
tpu_inference/models/jax/llama3.py
CHANGED
@@ -8,10 +8,10 @@ from transformers import LlamaConfig, modeling_flax_utils
 from vllm.config import VllmConfig
 
 from tpu_inference import utils
+from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.rope_interface import apply_rope
-from tpu_inference.layers.jax.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
                                                          load_hf_weights)
tpu_inference/models/jax/phi3.py
CHANGED
@@ -8,8 +8,8 @@ from transformers import Phi3Config, modeling_flax_utils
 from vllm.config import VllmConfig
 
 from tpu_inference import utils
+from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.attention_interface import attention
 from tpu_inference.layers.jax.rope_interface import apply_longrope, apply_rope
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.weight_utils import (MetadataMap,
tpu_inference/models/jax/qwen2.py
CHANGED
@@ -8,8 +8,8 @@ from transformers import Qwen2Config, modeling_flax_utils
 from vllm.config import VllmConfig
 
 from tpu_inference import utils
+from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.attention_interface import attention
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
tpu_inference/models/jax/qwen2_5_vl.py
CHANGED
@@ -14,9 +14,9 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
 from vllm.config import VllmConfig
 
 from tpu_inference import utils as utils
-from tpu_inference.layers.common.
-from tpu_inference.layers.jax.attention_interface import \
+from tpu_inference.layers.common.attention_interface import \
     sharded_flash_attention
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
 # from vllm.model_executor.models.interfaces import MultiModalEmbeddings
tpu_inference/models/jax/qwen3.py
CHANGED
@@ -8,8 +8,8 @@ from transformers import Qwen3Config
 from vllm.config import VllmConfig
 
 from tpu_inference import utils
+from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.attention_interface import attention
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.qwen2 import Qwen2DecoderLayer
tpu_inference/models/vllm/vllm_model_wrapper.py
CHANGED
@@ -25,6 +25,8 @@ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
 from tpu_inference.layers.vllm.sharding import shard_model_to_tpu
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.models.vllm.vllm_model_wrapper_context import (
     get_vllm_model_wrapper_context, set_vllm_model_wrapper_context)
 from tpu_inference.runner.lora_utils import replace_lora_metadata
@@ -89,13 +91,14 @@ class VllmModelWrapper:
         slice_config = self.vllm_config.device_config.slice
         modified_slice_config = True
         self.vllm_config.device_config.slice = None
+        self.vllm_config.compilation_config.static_forward_context.clear()
+
         vllm_config_for_load = copy.deepcopy(self.vllm_config)
         if modified_slice_config:
             self.vllm_config.device_config.slice = slice_config
         assert self.vllm_config.model_config.dtype in TORCH_DTYPE_TO_JAX, "The model_config.dtype must be a PyTorch dtype."
         vllm_config_for_load.device_config.device = "cpu"
         # Clearing the cached compilation config, otherwise vllm model init will fail
-        vllm_config_for_load.compilation_config.static_forward_context.clear()
 
         # When expert parallelism is enabled, vLLM loads weight in sharding
         # aware manner. Since tpu-inference has its own sharding logic, this
@@ -117,7 +120,8 @@ class VllmModelWrapper:
 
         # Load the vLLM model and wrap it into a new model whose forward
         # function can calculate the hidden_state and logits.
-
+        available_devices = self.mesh.devices.flatten()
+        with load_context, jax.default_device(available_devices[0]):
             vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
             lora_manager = None
             if vllm_config_for_load.lora_config is not None:
@@ -149,7 +153,8 @@ class VllmModelWrapper:
                 "xla_tpu_reduce_scatter_collective_matmul_mode":
                 "post_spmd_conservative"
             },
-            static_argnames=("layer_name_to_kvcache_index",
+            static_argnames=("layer_name_to_kvcache_index", "is_first_rank",
+                             "is_last_rank"),
         )
         def step_fun(
             params_and_buffers,  # This has been wrapped into torchax TorchValue
@@ -159,6 +164,9 @@ class VllmModelWrapper:
             input_embeds: jax.Array,
             layer_name_to_kvcache_index: Sequence[Tuple[str, int]],
             lora_metadata,
+            intermediate_tensors: JaxIntermediateTensors = None,
+            is_first_rank: bool = True,
+            is_last_rank: bool = True,
             *args,
         ) -> Tuple[List[jax.Array], jax.Array]:
             layer_name_to_kvcache_index = dict(layer_name_to_kvcache_index)
@@ -173,13 +181,15 @@ class VllmModelWrapper:
             # torch_view in order to call the Torch function.
             original_lora_metadata = replace_lora_metadata(
                 self.model, lora_metadata, self.vllm_config.lora_config)
-
+            if not is_first_rank:
+                intermediate_tensors = intermediate_tensors.to_torch()
+            output_from_torch = torch.func.functional_call(
                 self.model,
                 torch_view(params_and_buffers),
                 kwargs={
                     "input_ids": torch_view(input_ids),
                     "positions": torch_view(attn_metadata.input_positions),
-                    "intermediate_tensors":
+                    "intermediate_tensors": intermediate_tensors,
                     "inputs_embeds": None,
                 },
                 tie_weights=False,
@@ -188,11 +198,13 @@ class VllmModelWrapper:
                 self.vllm_config.lora_config)
             vllm_model_wrapper_context = get_vllm_model_wrapper_context()
             new_kv_caches = vllm_model_wrapper_context.kv_caches
-            # Wrap the
-            # code to consume.
-
-
-
+            # Wrap the output(hidden states or intermediate tensor)
+            # from torch land into a JaxValue for the jax code to consume.
+            if not is_last_rank:
+                output = JaxIntermediateTensors.from_torch(output_from_torch)
+            else:
+                output = jax_view(output_from_torch)
+            return new_kv_caches, output, []
 
         return step_fun
 
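Note on the vllm_model_wrapper.py changes above: step_fun now threads pipeline-parallel state through the jitted step. Non-first ranks receive an intermediate-tensors container from the previous rank and convert it to the torch side before the forward call; non-last ranks convert the torch output back into the JAX-side container for the next rank. A stand-alone sketch of that branching with a stub container (JaxIntermediateTensors' real interface is not shown in this diff, so the stub below is an assumption for illustration only):

from dataclasses import dataclass
from typing import Any, Dict

@dataclass
class IntermediateTensorsStub:
    # Hypothetical stand-in for JaxIntermediateTensors.
    tensors: Dict[str, Any]

    def to_torch(self) -> "IntermediateTensorsStub":
        return self  # the real class would wrap its arrays with torch_view()

    @classmethod
    def from_torch(cls, torch_out: Any) -> "IntermediateTensorsStub":
        return cls(tensors={"hidden_states": torch_out})  # real class wraps with jax_view()

def run_rank(model_fn, inputs, intermediate, is_first_rank: bool, is_last_rank: bool):
    # Non-first ranks consume the previous rank's intermediate tensors.
    if not is_first_rank:
        inputs = intermediate.to_torch()
    out = model_fn(inputs)
    # Non-last ranks hand their output to the next rank; the last rank
    # returns hidden states for logits/sampling.
    return out if is_last_rank else IntermediateTensorsStub.from_torch(out)

print(run_rank(lambda x: x * 2, 3, None, True, False).tensors)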
tpu_inference/platforms/tpu_platform.py
CHANGED
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import os
-from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
 
 import jax.numpy as jnp
 import vllm.envs as vllm_envs
@@ -12,7 +12,7 @@ from vllm.platforms.interface import Platform, PlatformEnum
 from vllm.sampling_params import SamplingParams, SamplingType
 
 from tpu_inference import envs
-from tpu_inference.layers.
+from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
 
 if TYPE_CHECKING:
@@ -57,7 +57,8 @@ class TpuPlatform(Platform):
     def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
                              dtype: jnp.dtype, kv_cache_dtype: Optional[str],
                              block_size: int, use_v1: bool, use_mla: bool,
-                             has_sink: bool, use_sparse: bool
+                             has_sink: bool, use_sparse: bool,
+                             attn_type: Any) -> str:
         from vllm.attention.backends.registry import _Backend
         if selected_backend != _Backend.PALLAS:
             logger.info("Cannot use %s backend on TPU.", selected_backend)
@@ -184,8 +185,14 @@
 
         multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
         if not multihost_backend:  # Single host
-
-
+            if parallel_config.pipeline_parallel_size == 1:
+                logger.info("Force using UniProcExecutor for JAX on \
+                    single host without pipeline parallelism.")
+                parallel_config.distributed_executor_backend = "uni"
+            else:
+                logger.info("Force using MultiprocExecutor for JAX on \
+                    single host with pipeline parallelism.")
+                parallel_config.distributed_executor_backend = "mp"
         elif multihost_backend == "ray":
             from tpu_inference.executors.ray_distributed_executor import \
                 RayDistributedExecutor
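Note on the tpu_platform.py change above: executor selection on a single host now depends on pipeline parallelism, while the multihost "ray" backend still routes to the Ray distributed executor. Reduced to a pure function, the decision looks roughly like this (a sketch, not the package API; behavior for other multihost backends is not shown in this diff):

def pick_executor_backend(multihost_backend: str, pipeline_parallel_size: int) -> str:
    # Mirrors the branching added above, isolated from the config objects.
    if not multihost_backend:  # single host
        return "uni" if pipeline_parallel_size == 1 else "mp"
    if multihost_backend == "ray":
        return "ray"  # resolved to RayDistributedExecutor by the caller
    raise NotImplementedError("other multihost backends not covered in this diff")

assert pick_executor_backend("", 1) == "uni"
assert pick_executor_backend("", 2) == "mp"
assert pick_executor_backend("ray", 4) == "ray"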
tpu_inference/runner/compilation_manager.py
CHANGED
@@ -10,10 +10,10 @@ from jax.sharding import NamedSharding, PartitionSpec
 
 from tpu_inference.core.disagg_utils import is_disagg_enabled
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.sample.sampling import sample
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
-from tpu_inference.layers.jax.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 from tpu_inference.utils import device_array
 
@@ -332,13 +332,15 @@ class CompilationManager:
         index_paddings = self.runner.num_reqs_paddings
         dp_sharding = NamedSharding(self.runner.mesh,
                                     PartitionSpec(ShardingAxisName.ATTN_DATA))
+        hidden_states_sharding = NamedSharding(
+            self.runner.mesh, PartitionSpec(ShardingAxisName.ATTN_DATA, None))
         dp_size = self.runner.vllm_config.sharding_config.total_dp_size
         self._precompile_select_from_array_helper(
             name="select all logits",
             source_paddings=self.runner.num_tokens_paddings,
             indices_paddings=index_paddings,
             hidden_dim=hsize,
-            input_sharding=
+            input_sharding=hidden_states_sharding,
             indices_sharding=dp_sharding if dp_size > 1 else None,
         )
 
tpu_inference/runner/kv_cache.py
CHANGED
@@ -9,7 +9,7 @@ from torchax.ops.mappings import t2j_dtype
 
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel as rpa
 import tpu_inference.kernels.ragged_paged_attention.v3.kernel_hd64 as rpa_hd64
-from tpu_inference.layers.
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 
 logger = init_logger(__name__)
tpu_inference/runner/tpu_runner.py
CHANGED
@@ -27,7 +27,7 @@ from vllm.v1.core.sched.output import GrammarOutput
 from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
-                             DraftTokenIds, KVConnectorOutput,
+                             DraftTokenIds, KVConnectorOutput, LogprobsLists,
                              ModelRunnerOutput)
 from vllm.v1.request import Request
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
@@ -37,15 +37,15 @@ from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 
 from tpu_inference import utils as common_utils
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+from tpu_inference.layers.common.sharding import (MESH_AXIS_NAMES,
+                                                  MESH_AXIS_NAMES_2D,
+                                                  ShardingAxisName,
+                                                  ShardingConfigManager)
 from tpu_inference.layers.jax.sample.rejection_sampler import RejectionSampler
 from tpu_inference.layers.jax.sample.sampling import (compute_logprobs,
                                                       gather_logprobs, sample)
 from tpu_inference.layers.jax.sample.sampling_metadata import \
     TPUSupportedSamplingMetadata
-from tpu_inference.layers.jax.sharding import (MESH_AXIS_NAMES,
-                                               MESH_AXIS_NAMES_2D,
-                                               ShardingAxisName,
-                                               ShardingConfigManager)
 from tpu_inference.logger import init_logger
 from tpu_inference.models.common.model_loader import get_model
 from tpu_inference.models.jax.utils.weight_utils import (
@@ -190,6 +190,21 @@ def _substitute_placeholder_token(
     return input_ids.at[token_in_tpu_cur_input_indices].set(update_values)
 
 
+def _reorder_logits_indices(logprobs_lists, logits_indices_selector):
+    return LogprobsLists(
+        logprob_token_ids=[
+            logprobs_lists.logprob_token_ids[i]
+            for i in logits_indices_selector
+        ],
+        logprobs=[logprobs_lists.logprobs[i] for i in logits_indices_selector],
+        sampled_token_ranks=[
+            logprobs_lists.sampled_token_ranks[i]
+            for i in logits_indices_selector
+        ],
+        cu_num_generated_tokens=logprobs_lists.cu_num_generated_tokens,
+    )
+
+
 class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
 
     def __init__(
@@ -840,7 +855,12 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
                 logits_indices_selector)
 
         if logprobs is not None:
+            # Map logprobs back to the pre-dp shuffling order
             logprobs_lists = logprobs.tolists()
+            if logits_indices_selector is not None:
+                logprobs_lists = _reorder_logits_indices(
+                    logprobs_lists, logits_indices_selector)
+
         else:
             logprobs_lists = None
 
@@ -908,7 +928,11 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             req_state.output_token_ids.extend(sampled_ids)
 
         if logprobs is not None:
+            # Map logprobs back to the pre-dp shuffling order
             logprobs_lists = logprobs.tolists()
+            if logits_indices_selector is not None:
+                logprobs_lists = _reorder_logits_indices(
+                    logprobs_lists, logits_indices_selector)
         else:
             logprobs_lists = None
 
@@ -1315,10 +1339,10 @@ class TPUModelRunner(KVConnectorModelRunnerMixin, LoRAModelRunnerMixin):
             seq_lens_cpu = seq_lens
 
         (input_ids, positions, block_tables, query_start_loc, seq_lens,
-         logits_indices, request_distribution
+         logits_indices, request_distribution) = device_array(
             self.mesh,
             (input_ids, positions, block_tables, query_start_loc, seq_lens,
-             logits_indices, request_distribution
+             logits_indices, request_distribution),
             sharding=data_parallel_attn_sharding,
         )
         # Async scheduling: substitute placeholder tokens for DP
tpu_inference/utils.py
CHANGED
@@ -132,8 +132,8 @@ def pathways_hbm_usage_gb(devices: Any) -> List[Tuple[float, float]]:
     hbm_used = defaultdict(int)
     hbm_limit = get_device_hbm_limit()
     for array in live_arrays:
-        for buffer in array.
-            hbm_used[buffer.device] += buffer.nbytes
+        for buffer in array.addressable_shards:
+            hbm_used[buffer.data.device] += buffer.data.nbytes
     return [(hbm_used[device], hbm_limit) for device in devices]
 
 
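Note on the utils.py fix above: per-device HBM accounting now iterates array.addressable_shards and reads each shard's backing buffer (shard.data) for its device and byte count, which works for sharded jax.Arrays. A minimal version of the same accounting that runs on any backend (device names and byte counts depend on the local runtime):

from collections import defaultdict

import jax.numpy as jnp

def bytes_used_per_device(arrays):
    # Sum the bytes of every addressable shard, keyed by the device it lives on.
    used = defaultdict(int)
    for array in arrays:
        for shard in array.addressable_shards:
            used[shard.data.device] += shard.data.nbytes
    return dict(used)

x = jnp.ones((1024, 1024), dtype=jnp.float32)
print(bytes_used_per_device([x]))  # e.g. {CpuDevice(id=0): 4194304} on a CPU backend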
tpu_inference/worker/tpu_worker.py
CHANGED
@@ -25,7 +25,7 @@ from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 from tpu_inference import envs, utils
 from tpu_inference.distributed.utils import (get_host_ip, get_kv_transfer_port,
                                              get_node_id)
-from tpu_inference.layers.
+from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
 from tpu_inference.runner.kv_cache import get_rpa_page_size_bytes
 from tpu_inference.runner.tpu_runner import TPUModelRunner