tpu-inference 0.11.1.dev202511130813__py3-none-any.whl → 0.11.1.dev202511220812__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/lora/test_layers.py +0 -6
- tests/lora/utils.py +0 -8
- tests/test_envs.py +182 -0
- tests/test_utils.py +23 -14
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/core_tpu.py +17 -9
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +2 -3
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +1 -1
- tpu_inference/executors/ray_distributed_executor.py +27 -11
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +110 -64
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +7 -0
- tpu_inference/layers/{jax → common}/attention_interface.py +1 -1
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/jax/attention/attention.py +1 -1
- tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
- tpu_inference/layers/jax/sample/sampling.py +2 -2
- tpu_inference/layers/vllm/attention.py +1 -1
- tpu_inference/layers/vllm/quantization/__init__.py +7 -3
- tpu_inference/layers/vllm/quantization/awq.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -2
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +4 -3
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +12 -11
- tpu_inference/models/jax/llama3.py +4 -3
- tpu_inference/models/jax/llama_eagle3.py +9 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +3 -2
- tpu_inference/models/jax/qwen2_5_vl.py +4 -3
- tpu_inference/models/jax/qwen3.py +3 -2
- tpu_inference/models/jax/utils/weight_utils.py +21 -8
- tpu_inference/models/vllm/vllm_model_wrapper.py +22 -10
- tpu_inference/platforms/tpu_platform.py +17 -7
- tpu_inference/runner/compilation_manager.py +37 -17
- tpu_inference/runner/kv_cache.py +1 -1
- tpu_inference/runner/kv_cache_manager.py +8 -2
- tpu_inference/runner/tpu_runner.py +199 -87
- tpu_inference/spec_decode/jax/eagle3.py +2 -1
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +7 -6
- tpu_inference/worker/tpu_worker.py +159 -23
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/METADATA +2 -2
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/RECORD +52 -54
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- /tpu_inference/layers/{jax → common}/binary_search.py +0 -0
- /tpu_inference/layers/{jax → common}/sharding.py +0 -0
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/mxfp4.py (new file)
@@ -0,0 +1,266 @@
+from typing import Callable, Optional, Union
+
+import jax
+import jax.numpy as jnp
+import torch
+from jax.experimental.layout import Format, Layout
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from torch.nn.parameter import Parameter
+from torchax.interop import jax_view, torch_view
+from torchax.ops.mappings import t2j
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEConfig, FusedMoEQuantConfig, biased_moe_quant_config)
+from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
+                                                         FusedMoEMethodBase)
+from vllm.model_executor.layers.linear import LinearBase
+from vllm.model_executor.layers.quantization import \
+    register_quantization_config
+from vllm.model_executor.layers.quantization.base_config import \
+    QuantizeMethodBase
+from vllm.model_executor.layers.quantization.mxfp4 import (Mxfp4Backend,
+                                                           Mxfp4Config,
+                                                           Mxfp4MoEMethod)
+from vllm.model_executor.layers.quantization.utils.quant_utils import \
+    is_layer_skipped
+
+from tpu_inference.layers.common.quant_methods import (MXFP4,
+                                                        get_tpu_quant_method)
+from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
+from tpu_inference.layers.vllm.linear_common import \
+    reorder_concatenated_tensor_for_sharding
+from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
+from tpu_inference.layers.vllm.quantization.unquantized import \
+    VllmUnquantizedLinearMethod
+
+MXFP4_BLOCK_SIZE = 32
+
+P = PartitionSpec
+logger = init_logger(__name__)
+
+
+# TODO(kyuyeunk): Move these functions into a common utility file.
+def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:
+    assert u8_packed_e2m1.dtype == jnp.uint8
+    e2m1 = jax.lax.bitcast_convert_type(u8_packed_e2m1, jnp.float4_e2m1fn)
+    # bitcast creates one more dimension that splits 8 bits into two e2m1.
+    # we flatten them with the last dim.
+    return jnp.reshape(e2m1, e2m1.shape[:-2] + (-1, ))
+
+
+def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
+    e8_finfo = jnp.finfo(jnp.float8_e8m0fnu)
+    exponents = u8.astype(jnp.int32) + e8_finfo.minexp
+    ones = jnp.ones_like(u8, dtype=jnp.float32)
+    return jnp.ldexp(ones, exponents)
+
+
+def dequantize_block_weight(weight: jax.Array,
+                            scale: jax.Array,
+                            block_size: int,
+                            out_dtype: jnp.dtype = jnp.bfloat16) -> jax.Array:
+    orig_shape = weight.shape
+    weight_block = weight.reshape(orig_shape[:-1] + (-1, block_size))
+    weight_dequantized = weight_block.astype(jnp.float32) * jnp.expand_dims(
+        scale, -1)
+    return weight_dequantized.reshape(orig_shape).astype(out_dtype)
+
+
+@register_quantization_config(get_tpu_quant_method(MXFP4))
+class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
+
+    @classmethod
+    def get_name(cls):
+        return MXFP4
+
+    def get_quant_method(self, layer: torch.nn.Module,
+                         prefix: str) -> Optional["QuantizeMethodBase"]:
+        from vllm.attention.layer import Attention  # Avoid circular import
+
+        if isinstance(layer, LinearBase):
+            linear_config = self.get_linear_config(layer)
+            if self.ignored_layers and is_layer_skipped(
+                    prefix=prefix,
+                    ignored_layers=self.ignored_layers,
+                    fused_mapping=self.packed_modules_mapping,
+            ):
+                return VllmUnquantizedLinearMethod(linear_config)
+            # TODO: Add support for MXFP4 Linear Method.
+            # MXFP4 LinearMethod is available in AMD-Quark, refer to that
+            # implementation if you are interested in enabling MXFP4 here.
+            logger.warning_once(
+                "MXFP4 linear layer is not implemented - falling back to "
+                "UnquantizedLinearMethod.")
+            return VllmUnquantizedLinearMethod(linear_config)
+        elif isinstance(layer, FusedMoE):
+            return VllmMxfp4MoEMethod(layer.moe_config, self.mesh)
+        elif isinstance(layer, Attention):
+            # TODO: Add support for MXFP4 Attention.
+            logger.warning_once("MXFP4 attention layer is not implemented. "
+                                "Skipping quantization for this layer.")
+        return None
+
+
+class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
+
+    def __init__(self, moe: FusedMoEConfig, mesh: Mesh):
+        FusedMoEMethodBase.__init__(self, moe)
+
+        # We piggyback on triton implementation as it applies minimal hardware
+        # specific post processing to the weights.
+        self.mxfp4_backend = Mxfp4Backend.TRITON
+        self.mesh = mesh
+
+    def get_fused_moe_quant_config(
+            self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
+        # Because we have dequantized weights, we only need biased moe config.
+        # TODO(kyuyeunk): Add native support for MXFP4.
+        return biased_moe_quant_config(
+            layer.w13_bias,
+            layer.w2_bias,
+        )
+
+    def process_weights_after_loading(self, layer: torch.nn.Module):
+        assert isinstance(layer, FusedMoE)
+
+        w13_weight = u8_unpack_e2m1(t2j(layer.w13_weight, use_dlpack=False))
+        w13_weight_scale = e8m0_to_fp32(
+            t2j(layer.w13_weight_scale, use_dlpack=False))
+        w13_bias = t2j(layer.w13_bias, use_dlpack=False)
+
+        w2_weight = u8_unpack_e2m1(t2j(layer.w2_weight, use_dlpack=False))
+        w2_weight_scale = e8m0_to_fp32(
+            t2j(layer.w2_weight_scale, use_dlpack=False))
+        w2_bias = t2j(layer.w2_bias, use_dlpack=False)
+
+        # We dequantize fp4 weights into bf16.
+        # TODO(kyuyeunk): Add native support for MXFP4.
+        w13_weight = dequantize_block_weight(w13_weight, w13_weight_scale,
+                                             MXFP4_BLOCK_SIZE, jnp.bfloat16)
+        w2_weight = dequantize_block_weight(w2_weight, w2_weight_scale,
+                                            MXFP4_BLOCK_SIZE, jnp.bfloat16)
+
+        # Because we have dequantized weights, scales are not used anymore.
+        delattr(layer, "w13_weight_scale")
+        delattr(layer, "w2_weight_scale")
+
+        if layer.activation == "swigluoai":
+            # When using swigluoai, vLLM splits gmm output in a interleaved way.
+            # However, interleaved split is not performant on TPU. Therefore,
+            # we preprocess the weight so that splitting gmm output by middle
+            # can still get the same result.
+            w1_weight = w13_weight[:, ::2, :]
+            w3_weight = w13_weight[:, 1::2, :]
+            w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
+
+            w1_bias = w13_bias[:, ::2]
+            w3_bias = w13_bias[:, 1::2]
+            w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
+
+        # TODO(kyuyeunk): Add weight processing logic for the new kernel.
+        if layer.use_ep:
+            w13_weight = jax.device_put(
+                w13_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P("model", None, None))))
+            w2_weight = jax.device_put(
+                w2_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P("model", None, None))))
+
+            w13_bias = jax.device_put(
+                w13_bias,
+                Format(Layout((0, 1)),
+                       NamedSharding(self.mesh, P("model", None))))
+            w2_bias = jax.device_put(
+                w2_bias,
+                Format(Layout((0, 1)),
+                       NamedSharding(self.mesh, P("model", None))))
+
+        else:
+            intermediate_size = w13_weight.shape[1] // 2
+            assert intermediate_size == w2_weight.shape[-1]
+            output_sizes = [intermediate_size, intermediate_size]
+            n_shards = self.mesh.shape["model"]
+            assert intermediate_size % n_shards == 0
+            w13_weight = reorder_concatenated_tensor_for_sharding(w13_weight,
+                                                                  output_sizes,
+                                                                  n_shards,
+                                                                  dim=1)
+            w13_weight = jax.device_put(
+                w13_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P(None, "model", None))))
+            w2_weight = jax.device_put(
+                w2_weight,
+                Format(Layout((0, 1, 2)),
+                       NamedSharding(self.mesh, P(None, None, "model"))))
+
+            w13_bias = reorder_concatenated_tensor_for_sharding(w13_bias,
+                                                                output_sizes,
+                                                                n_shards,
+                                                                dim=1)
+            w13_bias = jax.device_put(
+                w13_bias,
+                Format(Layout((0, 1)),
+                       NamedSharding(self.mesh, P(None, "model"))))
+            w2_bias = jax.device_put(
+                w2_bias,
+                Format(Layout((0, 1)), NamedSharding(self.mesh, P(None,
+                                                                  None))))
+
+        layer.w13_weight = Parameter(torch_view(w13_weight),
+                                     requires_grad=False)
+        layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
+
+        layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
+        layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
+
+        pass
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        routed_scaling_factor: float = 1.0,
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
+        enable_eplb: bool = False,
+        expert_load_view: Optional[torch.Tensor] = None,
+        logical_to_physical_map: Optional[torch.Tensor] = None,
+        logical_replica_count: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        assert isinstance(layer, FusedMoE)
+        if scoring_func != "softmax":
+            raise NotImplementedError(
+                "Only softmax is supported for scoring_func")
+
+        # Use the original implementation
+        output = fused_moe_func_padded(
+            jax_view(x),
+            jax_view(layer.w13_weight),
+            jax_view(layer.w2_weight),
+            jax_view(layer.w13_bias) if self.moe.has_bias else None,
+            jax_view(layer.w2_bias) if self.moe.has_bias else None,
+            jax_view(router_logits),
+            topk=top_k,
+            global_num_experts=global_num_experts,
+            renormalize=renormalize,
+            reduce_results=layer.reduce_results,
+            mesh=self.mesh,
+            use_ep=layer.use_ep,
+            activation=activation,
+        )
+
+        return torch_view(output)
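For readers skimming the hunk above: `u8_unpack_e2m1` splits each packed uint8 into two float4 (e2m1) values, `e8m0_to_fp32` turns the 8-bit exponent scales into float32 powers of two, and `dequantize_block_weight` multiplies each 32-element block by its scale. The snippet below is a minimal, self-contained sketch of that last step only, with made-up shapes and plain float32 inputs; it is not code from the package.

# Block-wise dequantization sketch: each group of `block_size` values along
# the last axis shares one scale entry (the kernel above uses block_size = 32).
import jax.numpy as jnp

block_size = 4
weight = jnp.arange(16, dtype=jnp.float32).reshape(2, 8)              # (rows, K)
scale = jnp.array([[0.5, 2.0], [1.0, 0.25]], dtype=jnp.float32)       # (rows, K // block_size)

blocks = weight.reshape(weight.shape[:-1] + (-1, block_size))         # (rows, K // bs, bs)
dequant = (blocks * scale[..., None]).reshape(weight.shape)           # broadcast one scale per block
print(dequant.astype(jnp.bfloat16))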
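The `swigluoai` branch above also regroups the fused w1/w3 weight so that a cheap split down the middle of the gmm output replaces a strided, interleaved split. A toy illustration of that regrouping, with invented values and a 2-D array instead of the real 3-D per-expert weights:

# Interleaved -> contiguous layout: after regrouping, w1 occupies the first
# half and w3 the second half along the feature axis.
import jax.numpy as jnp

w13 = jnp.array([[10., 30., 11., 31., 12., 32.]])   # elements interleaved: w1, w3, w1, w3, ...
w1 = w13[:, ::2]                                    # strided pick of w1 entries
w3 = w13[:, 1::2]                                   # strided pick of w3 entries
w13_contiguous = jnp.concatenate([w1, w3], axis=1)
print(w13_contiguous)                               # [[10. 11. 12. 30. 31. 32.]]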
tpu_inference/layers/vllm/quantization/unquantized.py
@@ -23,6 +23,8 @@ from vllm.model_executor.layers.quantization.base_config import (
 
 from tpu_inference import envs
 from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
+from tpu_inference.layers.common.quant_methods import (UNQUANTIZED,
+                                                        get_tpu_quant_method)
 from tpu_inference.layers.vllm.fused_moe import fused_moe_func_padded
 from tpu_inference.layers.vllm.linear_common import (
     reorder_concatenated_tensor_for_sharding,
@@ -34,12 +36,12 @@ P = PartitionSpec
 logger = init_logger(__name__)
 
 
-@register_quantization_config(
+@register_quantization_config(get_tpu_quant_method(UNQUANTIZED))
 class VllmUnquantizedConfig(QuantizationConfig, JaxCommonConfig):
 
     @classmethod
     def get_name(cls) -> str:
-        return
+        return UNQUANTIZED
 
     @classmethod
     def get_supported_act_dtypes(cls) -> list[torch.dtype]:
@@ -189,7 +191,6 @@ class VllmUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         assert isinstance(layer, FusedMoE)
-
         w13_weight = t2j(layer.w13_weight, use_dlpack=False)
         w2_weight = t2j(layer.w2_weight, use_dlpack=False)
 
tpu_inference/layers/vllm/sharding.py
@@ -19,6 +19,7 @@ from vllm.lora.layers.base_linear import BaseLinearLayerWithLoRA
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 
+from tpu_inference import envs
 from tpu_inference.logger import init_logger
 
 P = PartitionSpec
@@ -211,8 +212,7 @@ def _shard_module_to_tpu(model: torch.nn.Module, mesh: Mesh) -> None:
 def _sharded_device_put(tensor: jax.Array, sharding) -> jax.Array:
     if isinstance(tensor, tuple):
         return tuple(_sharded_device_put(t, sharding) for t in tensor)
-
-    multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+    multihost_backend = envs.TPU_MULTIHOST_BACKEND
    if multihost_backend != "ray":
         return jax.device_put(tensor, sharding)
 
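The second hunk swaps the inline `os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()` lookup for the package's central `envs` module (the `tpu_inference/envs.py` file also changes in this release). The actual module is not shown in this diff; a hypothetical sketch of the pattern, matching the semantics of the removed line, might look like:

# Hypothetical env-flag accessor; the real tpu_inference.envs module may
# differ. The point is one typed place to read TPU_* flags instead of
# scattered os.environ.get() calls.
import os


class _Envs:

    @property
    def TPU_MULTIHOST_BACKEND(self) -> str:
        # Same defaults as the removed inline call: empty string, lower-cased.
        return os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()


envs = _Envs()

if envs.TPU_MULTIHOST_BACKEND != "ray":
    print("single-host or non-Ray multihost backend")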
tpu_inference/lora/torch_punica_tpu.py
@@ -239,7 +239,6 @@ class PunicaWrapperTPU(PunicaWrapperBase):
         lora_index_to_id: list[Optional[int]],
         max_loras: int,
         vocab_size: int,
-        extra_vocab_size: int,
     ):
         # Pad the prompt mapping to avoid running into recompiles on the TPU
         # TODO: Should this happen inside mapping internally? If so how can we
@@ -258,7 +257,7 @@ class PunicaWrapperTPU(PunicaWrapperBase):
             lora_index_to_id,
             max_loras,
             vocab_size,
-            extra_vocab_size
+            0,  # extra_vocab_size
             "cpu",
         )
         with torchax.default_env():
tpu_inference/models/common/model_loader.py
@@ -11,7 +11,7 @@ from vllm.config import VllmConfig
 from vllm.utils.func_utils import supports_kw
 
 from tpu_inference import envs
-from tpu_inference.layers.
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.quantization.quantization_utils import (
     apply_qwix_on_abstract_model, apply_qwix_quantization,
@@ -36,19 +36,17 @@ def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
     from tpu_inference.models.jax.llama3 import LlamaForCausalLM
     from tpu_inference.models.jax.llama4 import Llama4ForCausalLM
     from tpu_inference.models.jax.llama_eagle3 import EagleLlama3ForCausalLM
-    from tpu_inference.models.jax.
-    from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
+    from tpu_inference.models.jax.llama_guard_4 import LlamaGuard4ForCausalLM
     from tpu_inference.models.jax.qwen2_5_vl import \
         Qwen2_5_VLForConditionalGeneration
     from tpu_inference.models.jax.qwen3 import Qwen3ForCausalLM
     _MODEL_REGISTRY["Llama4ForCausalLM"] = Llama4ForCausalLM
     _MODEL_REGISTRY["DeepseekV3ForCausalLM"] = DeepSeekV3
     _MODEL_REGISTRY["LlamaForCausalLM"] = LlamaForCausalLM
-    _MODEL_REGISTRY["
+    _MODEL_REGISTRY["Llama4ForConditionalGeneration"] = LlamaGuard4ForCausalLM
     _MODEL_REGISTRY["Qwen3ForCausalLM"] = Qwen3ForCausalLM
     _MODEL_REGISTRY[
         "Qwen2_5_VLForConditionalGeneration"] = Qwen2_5_VLForConditionalGeneration
-    _MODEL_REGISTRY["Phi3ForCausalLM"] = Phi3ForCausalLM
     _MODEL_REGISTRY["Eagle3LlamaForCausalLM"] = EagleLlama3ForCausalLM
     _MODEL_REGISTRY["GptOssForCausalLM"] = GptOss
 
@@ -57,8 +55,10 @@ def _get_model_architecture(config: PretrainedConfig) -> nnx.Module:
     if arch in _MODEL_REGISTRY:
         return _MODEL_REGISTRY[arch]
     raise UnsupportedArchitectureError(
-        f"Model architectures {architectures}
-
+        f"Model architectures {architectures} not "
+        "registered in tpu-inference. Falling back to vLLM-native "
+        f"Pytorch definition. JAX-native architectures: {list(_MODEL_REGISTRY.keys())}"
+    )
 
 
 def _get_nnx_model(
@@ -217,7 +217,7 @@ def get_flax_model(
             hidden_states_sharding,  # aux hidden states
         ),
         donate_argnums=2,  # 0 is graphdef, 1 is state, 2 is kv_cache
-        static_argnums=
+        static_argnums=7,  #7 is layer_name_to_kvcache_index
     )
     def run_model(graphdef, state, *args):
         model = nnx.merge(graphdef, state)
@@ -242,10 +242,11 @@ def get_flax_model(
         model = nnx.merge(graphdef, state)
         return model.get_multimodal_embeddings(image_grid_thw, **kwargs)
 
+    embed_sharding = NamedSharding(mesh, PartitionSpec(None))
     # This function will calculates the embeddings of input texts and then merge with the image embeddings
     @functools.partial(
         jax.jit,
-        out_shardings=(
+        out_shardings=(embed_sharding),
     )
     def run_get_input_embeddings(graphdef, state, *args, **kwargs):
         model = nnx.merge(graphdef, state)
@@ -325,8 +326,8 @@ def get_model(
         # Convert the error message to a string to check its contents
         error_msg = str(e)
 
-        logger.warning(
-
+        logger.warning(error_msg)
+
         # Fall back to the vLLM model and updating the dtype accordingly
         vllm_config.model_config.dtype = j2t_dtype(
             vllm_config.model_config.dtype.dtype)
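The `static_argnums=7` fix above pins `layer_name_to_kvcache_index` as a compile-time constant for `jax.jit`. The sketch below is a generic illustration of why a non-array mapping argument has to be static and hashable; the function name and signature are invented and unrelated to the runner's real jitted function (where the static argument sits at index 7).

# Generic jax.jit static_argnums illustration: the (name, index) mapping is a
# plain Python value, so it is marked static and becomes part of the
# compilation cache key; a new mapping value triggers a retrace.
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnums=1)  # argument 1 is the mapping
def gather_layer(kv, layer_name_to_index):
    idx = dict(layer_name_to_index)["layer.0"]  # resolved at trace time
    return kv[idx] * 2.0


kv = jnp.arange(6.0).reshape(3, 2)
# The static argument must be hashable, hence a tuple of pairs rather than a dict.
print(gather_layer(kv, (("layer.0", 1), ("layer.1", 2))))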
tpu_inference/models/jax/llama3.py
@@ -8,10 +8,10 @@ from transformers import LlamaConfig, modeling_flax_utils
 from vllm.config import VllmConfig
 
 from tpu_inference import utils
+from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.
+from tpu_inference.layers.common.sharding import ShardingAxisName
 from tpu_inference.layers.jax.rope_interface import apply_rope
-from tpu_inference.layers.jax.sharding import ShardingAxisName
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
                                                          load_hf_weights)
@@ -368,7 +368,8 @@ class LlamaForCausalLM(nnx.Module):
             "lm_head": "model.lm_head",
         })
 
-        metadata_map = get_default_maps(self.vllm_config
+        metadata_map = get_default_maps(self.vllm_config.model_config,
+                                        self.mesh, mappings)
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
tpu_inference/models/jax/llama_eagle3.py
@@ -194,13 +194,12 @@ class Eagle3LlamaModel(nnx.Module):
 
 def update_reshape_map_for_eagle3(vllm_config: VllmConfig,
                                   metadata_map: MetadataMap):
-    model_config = vllm_config.
+    model_config = vllm_config.speculative_config.draft_model_config
     hf_config = model_config.hf_config
 
     num_heads = hf_config.num_attention_heads
     num_kv_heads = hf_config.num_key_value_heads
-    hidden_size =
-
+    hidden_size = hf_config.hidden_size
     head_dim_original = model_config.get_head_size()
 
     metadata_map.reshape_map.update({
@@ -312,7 +311,11 @@ class EagleLlama3ForCausalLM(nnx.Module):
             r".*d2t.*",
         ]
 
-
+        # `embed_tokens` is shared between target and draft.
+        exclude_regex = [r".*embed_tokens.*"]
+        metadata_map = get_default_maps(
+            self.vllm_config.speculative_config.draft_model_config, self.mesh,
+            mappings)
 
         update_reshape_map_for_eagle3(self.vllm_config, metadata_map)
 
@@ -322,7 +325,8 @@ class EagleLlama3ForCausalLM(nnx.Module):
                         metadata_map=metadata_map,
                         mesh=self.mesh,
                         is_draft_model=True,
-                        keep_original_dtype_keys_regex=keep_original_dtype_keys_regex
+                        keep_original_dtype_keys_regex=keep_original_dtype_keys_regex,
+                        exclude_regex=exclude_regex if exclude_regex else None)
 
         # If the embedding is not initialized, initialize it with a dummpy array here to pass jit compilation. The real weights will be shared from the target model in eagle3 class.
         if isinstance(self.model.embed_tokens.embedding.value,
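The new `exclude_regex=[r".*embed_tokens.*"]` keeps the draft model from loading its own copy of the embedding that is shared with the target model. A toy illustration of regex-based checkpoint-key filtering (this is not the actual `load_hf_weights` implementation):

# Toy sketch: drop checkpoint keys whose names match any exclusion pattern.
import re

exclude_regex = [r".*embed_tokens.*"]
checkpoint_keys = [
    "model.embed_tokens.weight",               # shared with target -> skipped
    "model.layers.0.self_attn.q_proj.weight",
    "lm_head.weight",
]

kept = [
    key for key in checkpoint_keys
    if not any(re.match(pattern, key) for pattern in exclude_regex)
]
print(kept)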