tpu_inference-0.0.1rc1-py3-none-any.whl → tpu_inference-0.11.1.dev202511130813-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
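For readers who want to reproduce a file-level summary like the one below, here is an illustrative standard-library sketch; it is not how the registry generates this page, and the wheel filenames (taken from this page's title) are assumed to sit in the working directory:

import zipfile

old = zipfile.ZipFile("tpu_inference-0.0.1rc1-py3-none-any.whl")
new = zipfile.ZipFile("tpu_inference-0.11.1.dev202511130813-py3-none-any.whl")

old_names, new_names = set(old.namelist()), set(new.namelist())
for name in sorted(old_names | new_names):
    if name not in new_names:
        print(f"removed: {name}")
    elif name not in old_names:
        print(f"added:   {name}")
    elif old.read(name) != new.read(name):
        print(f"changed: {name}")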
- tests/kernels/fused_moe_v1_test.py +34 -303
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
- tests/lora/test_layers.py +6 -0
- tests/lora/utils.py +8 -0
- tests/test_utils.py +16 -24
- tpu_inference/__init__.py +3 -22
- tpu_inference/core/core_tpu.py +9 -17
- tpu_inference/core/disagg_utils.py +8 -6
- tpu_inference/distributed/tpu_connector.py +4 -3
- tpu_inference/distributed/utils.py +2 -3
- tpu_inference/envs.py +8 -61
- tpu_inference/executors/ray_distributed_executor.py +11 -31
- tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +143 -287
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -7
- tpu_inference/layers/jax/attention/attention.py +1 -1
- tpu_inference/layers/{common → jax}/attention_interface.py +2 -8
- tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
- tpu_inference/layers/jax/sample/sampling.py +2 -2
- tpu_inference/layers/{common → jax}/sharding.py +5 -5
- tpu_inference/layers/vllm/attention.py +1 -1
- tpu_inference/layers/vllm/fused_moe.py +208 -170
- tpu_inference/layers/vllm/quantization/__init__.py +3 -7
- tpu_inference/layers/vllm/quantization/awq.py +3 -4
- tpu_inference/layers/vllm/quantization/common.py +1 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +2 -4
- tpu_inference/layers/vllm/quantization/unquantized.py +67 -62
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +2 -1
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +28 -0
- tpu_inference/mock/vllm_envs.py +1219 -0
- tpu_inference/mock/vllm_logger.py +212 -0
- tpu_inference/mock/vllm_logging_utils.py +15 -0
- tpu_inference/models/common/model_loader.py +12 -46
- tpu_inference/models/jax/llama3.py +3 -4
- tpu_inference/models/jax/llama_eagle3.py +5 -8
- tpu_inference/models/jax/phi3.py +376 -0
- tpu_inference/models/jax/qwen2.py +2 -3
- tpu_inference/models/jax/qwen2_5_vl.py +50 -165
- tpu_inference/models/jax/qwen3.py +2 -3
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
- tpu_inference/models/jax/utils/weight_utils.py +143 -198
- tpu_inference/models/vllm/vllm_model_wrapper.py +14 -32
- tpu_inference/platforms/tpu_platform.py +34 -47
- tpu_inference/runner/compilation_manager.py +60 -145
- tpu_inference/runner/kv_cache.py +2 -2
- tpu_inference/runner/kv_cache_manager.py +18 -17
- tpu_inference/runner/persistent_batch_manager.py +2 -40
- tpu_inference/runner/structured_decoding_manager.py +3 -2
- tpu_inference/runner/tpu_runner.py +135 -283
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +21 -71
- tpu_inference/tpu_info.py +3 -4
- tpu_inference/utils.py +15 -38
- tpu_inference/worker/tpu_worker.py +26 -163
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/METADATA +3 -4
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/RECORD +63 -61
- tests/test_envs.py +0 -203
- tpu_inference/layers/common/quant_methods.py +0 -8
- tpu_inference/layers/vllm/quantization/mxfp4.py +0 -331
- tpu_inference/models/jax/llama_guard_4.py +0 -361
- /tpu_inference/layers/{common → jax}/binary_search.py +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/WHEEL +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/top_level.txt +0 -0
tpu_inference/layers/vllm/quantization/mxfp4.py (deleted)
@@ -1,331 +0,0 @@
-from typing import Callable, Optional, Union
-
-import jax
-import jax.numpy as jnp
-import torch
-from jax.experimental.layout import Format, Layout
-from jax.sharding import Mesh, NamedSharding, PartitionSpec
-from torch.nn.parameter import Parameter
-from torchax.interop import jax_view, torch_view
-from torchax.ops.mappings import t2j
-from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe.config import (
-    FusedMoEConfig, FusedMoEQuantConfig, biased_moe_quant_config)
-from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
-                                                        FusedMoEMethodBase)
-from vllm.model_executor.layers.linear import LinearBase
-from vllm.model_executor.layers.quantization import \
-    register_quantization_config
-from vllm.model_executor.layers.quantization.base_config import \
-    QuantizeMethodBase
-from vllm.model_executor.layers.quantization.mxfp4 import (Mxfp4Backend,
-                                                           Mxfp4Config,
-                                                           Mxfp4MoEMethod)
-from vllm.model_executor.layers.quantization.utils.quant_utils import \
-    is_layer_skipped
-
-from tpu_inference import envs
-from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
-from tpu_inference.layers.common.quant_methods import (MXFP4,
-                                                       get_tpu_quant_method)
-from tpu_inference.layers.vllm.fused_moe import fused_moe_func
-from tpu_inference.layers.vllm.linear_common import \
-    reorder_concatenated_tensor_for_sharding
-from tpu_inference.layers.vllm.quantization.common import JaxCommonConfig
-from tpu_inference.layers.vllm.quantization.unquantized import \
-    VllmUnquantizedLinearMethod
-
-MXFP4_BLOCK_SIZE = 32
-
-P = PartitionSpec
-logger = init_logger(__name__)
-
-
-# TODO(kyuyeunk): Move these functions into a common utility file.
-def u8_unpack_e2m1(u8_packed_e2m1: jax.Array) -> jax.Array:
-    assert u8_packed_e2m1.dtype == jnp.uint8
-    e2m1 = jax.lax.bitcast_convert_type(u8_packed_e2m1, jnp.float4_e2m1fn)
-    # bitcast creates one more dimension that splits 8 bits into two e2m1.
-    # we flatten them with the last dim.
-    return jnp.reshape(e2m1, e2m1.shape[:-2] + (-1, ))
-
-
-def e8m0_to_fp32(u8: jax.Array) -> jax.Array:
-    e8_finfo = jnp.finfo(jnp.float8_e8m0fnu)
-    exponents = u8.astype(jnp.int32) + e8_finfo.minexp
-    ones = jnp.ones_like(u8, dtype=jnp.float32)
-    return jnp.ldexp(ones, exponents)
-
-
-def dequantize_block_weight(weight: jax.Array,
-                            scale: jax.Array,
-                            block_size: int,
-                            out_dtype: jnp.dtype = jnp.bfloat16) -> jax.Array:
-    orig_shape = weight.shape
-    weight_block = weight.reshape(orig_shape[:-1] + (-1, block_size))
-    weight_dequantized = weight_block.astype(jnp.float32) * jnp.expand_dims(
-        scale, -1)
-    return weight_dequantized.reshape(orig_shape).astype(out_dtype)
-
-
-@register_quantization_config(get_tpu_quant_method(MXFP4))
-class VllmMxfp4Config(Mxfp4Config, JaxCommonConfig):
-
-    @classmethod
-    def get_name(cls):
-        return MXFP4
-
-    def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["QuantizeMethodBase"]:
-        from vllm.attention.layer import Attention  # Avoid circular import
-
-        if isinstance(layer, LinearBase):
-            linear_config = self.get_linear_config(layer)
-            if self.ignored_layers and is_layer_skipped(
-                    prefix=prefix,
-                    ignored_layers=self.ignored_layers,
-                    fused_mapping=self.packed_modules_mapping,
-            ):
-                return VllmUnquantizedLinearMethod(linear_config)
-            logger.warning_once(
-                "MXFP4 linear layer is not implemented - falling back to "
-                "UnquantizedLinearMethod.")
-            return VllmUnquantizedLinearMethod(linear_config)
-        elif isinstance(layer, FusedMoE):
-            moe_config = self.get_moe_config(layer)
-            return VllmMxfp4MoEMethod(moe_config, self.mesh)
-        elif isinstance(layer, Attention):
-            logger.warning_once("MXFP4 attention layer is not implemented. "
-                                "Skipping quantization for this layer.")
-        return None
-
-
-class VllmMxfp4MoEMethod(Mxfp4MoEMethod):
-
-    def __init__(self,
-                 moe: FusedMoEConfig,
-                 mesh: Mesh,
-                 ep_axis_name: str = 'model'):
-        FusedMoEMethodBase.__init__(self, moe)
-
-        # We piggyback on the triton implementation as it applies minimal
-        # hardware-specific post processing to the weights.
-        self.mxfp4_backend = Mxfp4Backend.TRITON
-
-        self.mesh = mesh
-        self.use_kernel = envs.USE_MOE_EP_KERNEL and moe.use_ep
-        self.ep_axis_name = ep_axis_name
-        # TODO: Use autotune table once we have it.
-        self.block_size = {
-            "bt": 64,
-            "bf": 1024,
-            "bd1": 1536,
-            "bd2": 1536,
-            "btc": 64,
-            "bfc": 1024,
-            "bd1c": 1536,
-            "bd2c": 1536,
-        }
-
-    def get_fused_moe_quant_config(
-            self, layer: torch.nn.Module) -> FusedMoEQuantConfig | None:
-        # Because we have dequantized weights, we only need biased moe config.
-        # TODO(kyuyeunk): Add native support for MXFP4.
-        return biased_moe_quant_config(
-            layer.w13_bias,
-            layer.w2_bias,
-        )
-
-    def process_weights_after_loading(self, layer: torch.nn.Module):
-        assert isinstance(layer, FusedMoE)
-        assert layer.moe_config.has_bias, "mxfp4 quantization always uses bias."
-
-        w13_weight = u8_unpack_e2m1(t2j(layer.w13_weight, use_dlpack=False))
-        w13_weight_scale = e8m0_to_fp32(
-            t2j(layer.w13_weight_scale, use_dlpack=False))
-        w13_bias = t2j(layer.w13_bias, use_dlpack=False)
-
-        w2_weight = u8_unpack_e2m1(t2j(layer.w2_weight, use_dlpack=False))
-        w2_weight_scale = e8m0_to_fp32(
-            t2j(layer.w2_weight_scale, use_dlpack=False))
-        w2_bias = t2j(layer.w2_bias, use_dlpack=False)
-
-        # We dequantize fp4 weights into bf16.
-        # TODO(kyuyeunk): Add native support for MXFP4.
-        w13_weight = dequantize_block_weight(w13_weight, w13_weight_scale,
-                                             MXFP4_BLOCK_SIZE, jnp.bfloat16)
-        w2_weight = dequantize_block_weight(w2_weight, w2_weight_scale,
-                                            MXFP4_BLOCK_SIZE, jnp.bfloat16)
-
-        num_experts, hidden_size, intermediate_size = w2_weight.shape
-
-        # Because we have dequantized weights, scales are not used anymore.
-        delattr(layer, "w13_weight_scale")
-        delattr(layer, "w2_weight_scale")
-
-        if layer.activation == "swigluoai":
-            # When using swigluoai, vLLM splits gmm output in an interleaved
-            # way. However, interleaved split is not performant on TPU.
-            # Therefore, we preprocess the weight so that splitting gmm output
-            # by middle can still get the same result.
-            w1_weight = w13_weight[:, ::2, :]
-            w3_weight = w13_weight[:, 1::2, :]
-            w13_weight = jnp.concat([w1_weight, w3_weight], axis=1)
-
-            w1_bias = w13_bias[:, ::2]
-            w3_bias = w13_bias[:, 1::2]
-            w13_bias = jnp.concat([w1_bias, w3_bias], axis=1)
-
-        if self.use_kernel:
-            # Kernel expects:
-            #   w13: (num_experts, 2, hidden_size, intermediate_size)
-            #   w2: (num_experts, intermediate_size, hidden_size)
-            # Current format:
-            #   w13_weight: (num_experts, 2*intermediate_size, hidden_size)
-            #   w2_weight: (num_experts, hidden_size, intermediate_size)
-
-            w13_reshaped = w13_weight.reshape(num_experts, 2,
-                                              intermediate_size, hidden_size)
-
-            # Transpose non-contracting dim to rightmost dim
-            w13_weight_transposed = jnp.swapaxes(w13_reshaped, 2, 3)
-            w2_weight_transposed = jnp.swapaxes(w2_weight, 1, 2)
-
-            # Apply EP sharding
-            ep_sharding = NamedSharding(self.mesh, P("model"))
-
-            w13_weight = jax.device_put(
-                w13_weight_transposed, Format(Layout((0, 1, 2, 3)),
-                                              ep_sharding))
-            w2_weight = jax.device_put(w2_weight_transposed,
-                                       Format(Layout((0, 1, 2)), ep_sharding))
-
-            w13_bias = w13_bias.reshape(num_experts, 2, intermediate_size)
-            w13_bias = jax.device_put(w13_bias,
-                                      Format(Layout((0, 1, 2)), ep_sharding))
-            w2_bias = jax.device_put(w2_bias,
-                                     Format(Layout((0, 1)), ep_sharding))
-
-        else:
-            if layer.use_ep:
-                ep_sharding = NamedSharding(self.mesh, P("model"))
-                w13_weight = jax.device_put(
-                    w13_weight, Format(Layout((0, 1, 2)), ep_sharding))
-                w2_weight = jax.device_put(
-                    w2_weight, Format(Layout((0, 1, 2)), ep_sharding))
-
-                w13_bias = jax.device_put(w13_bias,
-                                          Format(Layout((0, 1)), ep_sharding))
-                w2_bias = jax.device_put(w2_bias,
-                                         Format(Layout((0, 1)), ep_sharding))
-
-            else:
-                output_sizes = [intermediate_size, intermediate_size]
-                n_shards = self.mesh.shape["model"]
-                assert intermediate_size % n_shards == 0
-
-                w13_weight = reorder_concatenated_tensor_for_sharding(
-                    w13_weight,
-                    output_sizes,
-                    n_shards,
-                    dim=1,
-                )
-                w13_weight = jax.device_put(
-                    w13_weight,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P(None, "model", None))))
-                w2_weight = jax.device_put(
-                    w2_weight,
-                    Format(Layout((0, 1, 2)),
-                           NamedSharding(self.mesh, P(None, None, "model"))))
-
-                w13_bias = reorder_concatenated_tensor_for_sharding(
-                    w13_bias,
-                    output_sizes,
-                    n_shards,
-                    dim=1,
-                )
-                w13_bias = jax.device_put(
-                    w13_bias,
-                    Format(Layout((0, 1)),
-                           NamedSharding(self.mesh, P(None, "model"))))
-                w2_bias = jax.device_put(
-                    w2_bias,
-                    Format(Layout((0, 1)),
-                           NamedSharding(self.mesh, P(None, None))))
-
-        layer.w13_weight = Parameter(torch_view(w13_weight),
-                                     requires_grad=False)
-        layer.w2_weight = Parameter(torch_view(w2_weight), requires_grad=False)
-
-        layer.w13_bias = Parameter(torch_view(w13_bias), requires_grad=False)
-        layer.w2_bias = Parameter(torch_view(w2_bias), requires_grad=False)
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        routed_scaling_factor: float = 1.0,
-        e_score_correction_bias: Optional[torch.Tensor] = None,
-        apply_router_weight_on_input: bool = False,
-        activation: str = "silu",
-        enable_eplb: bool = False,
-        expert_load_view: Optional[torch.Tensor] = None,
-        logical_to_physical_map: Optional[torch.Tensor] = None,
-        logical_replica_count: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
-        assert isinstance(layer, FusedMoE)
-        if scoring_func != "softmax":
-            raise NotImplementedError(
-                "Only softmax is supported for scoring_func")
-
-        x = jax_view(x)
-        w13_weight = jax_view(layer.w13_weight)
-        w2_weight = jax_view(layer.w2_weight)
-        w13_bias = jax_view(layer.w13_bias)
-        w2_bias = jax_view(layer.w2_bias)
-        gating_output = jax_view(router_logits)
-
-        if self.use_kernel:
-            output = fused_ep_moe(
-                mesh=self.mesh,
-                tokens=x,
-                w1=w13_weight,
-                w2=w2_weight,
-                b1=w13_bias,
-                b2=w2_bias,
-                gating_output=gating_output,
-                top_k=top_k,
-                ep_axis_name=self.ep_axis_name,
-                renormalize_topk_logits=renormalize,
-                act_fn=activation,
-                **self.block_size,
-            )
-        else:
-            output = fused_moe_func(
-                hidden_states=x,
-                w1=w13_weight,
-                w2=w2_weight,
-                w1_bias=w13_bias,
-                w2_bias=w2_bias,
-                gating_output=gating_output,
-                topk=top_k,
-                renormalize=renormalize,
-                mesh=self.mesh,
-                use_ep=layer.use_ep,
-                activation=activation,
-            )
-
-        return torch_view(output)
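Note: two techniques in the removed mxfp4.py are worth seeing in isolation. First, MXFP4 block dequantization: weights arrive as packed 4-bit e2m1 values, with one exponent-only e8m0 scale shared by every block of 32 values. The sketch below is ours, not part of the package (BLOCK and unpack_and_dequant are hypothetical names); it assumes a JAX/ml_dtypes build that exposes float4_e2m1fn and float8_e8m0fnu, as the removed code did.

import jax
import jax.numpy as jnp

BLOCK = 32  # mirrors MXFP4_BLOCK_SIZE in the removed file

def unpack_and_dequant(packed_u8: jax.Array, scale_u8: jax.Array) -> jax.Array:
    # Each uint8 holds two e2m1 nibbles; the bitcast adds a trailing axis of
    # size 2, which we flatten into the last dimension.
    e2m1 = jax.lax.bitcast_convert_type(packed_u8, jnp.float4_e2m1fn)
    vals = e2m1.reshape(e2m1.shape[:-2] + (-1, ))
    # e8m0 stores only an exponent: scale = 2**(byte + minexp), so byte 127
    # decodes to a scale of 1.0 (minexp is -127).
    exp = scale_u8.astype(jnp.int32) + jnp.finfo(jnp.float8_e8m0fnu).minexp
    scale = jnp.ldexp(jnp.ones_like(scale_u8, dtype=jnp.float32), exp)
    # One scale per contiguous block of BLOCK values.
    blocks = vals.reshape(vals.shape[:-1] + (-1, BLOCK))
    out = blocks.astype(jnp.float32) * scale[..., None]
    return out.reshape(vals.shape).astype(jnp.bfloat16)

packed = jnp.arange(32, dtype=jnp.uint8)         # 64 fp4 values in 32 bytes
scales = jnp.array([127, 128], dtype=jnp.uint8)  # per-block scales 1.0, 2.0
print(unpack_and_dequant(packed, scales).shape)  # (64,)

Second, the swigluoai reordering: column selection commutes with matrix multiplication, so de-interleaving the output-feature dimension of w13 ahead of time makes a contiguous middle split of the GMM output equal to vLLM's strided split. A quick numeric check (illustrative only; 2-D for brevity, where the real weights carry a leading expert axis):

import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
x = jnp.asarray(rng.standard_normal((4, 16)), dtype=jnp.float32)
w13 = jnp.asarray(rng.standard_normal((16, 8)), dtype=jnp.float32)

y = x @ w13
gate_i, up_i = y[:, ::2], y[:, 1::2]                    # interleaved split
w13_reordered = jnp.concatenate([w13[:, ::2], w13[:, 1::2]], axis=1)
gate_c, up_c = jnp.split(x @ w13_reordered, 2, axis=1)  # middle split
assert jnp.allclose(gate_i, gate_c) and jnp.allclose(up_i, up_c)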
tpu_inference/models/jax/llama_guard_4.py (deleted)
@@ -1,361 +0,0 @@
-import re
-from typing import Any, List, Optional, Tuple
-
-import jax
-import jax.numpy as jnp
-import torch
-from flax import nnx
-from flax.typing import PRNGKey
-from jax.sharding import Mesh
-from jax.sharding import PartitionSpec as P
-from vllm.config import VllmConfig
-
-from tpu_inference.layers.jax.attention.attention import AttentionMetadata
-from tpu_inference.layers.jax.attention.llama4_attention import Llama4Attention
-from tpu_inference.layers.jax.constants import KVCacheType
-from tpu_inference.layers.jax.layers import DenseFFW, Embedder, LMhead, RMSNorm
-from tpu_inference.layers.jax.misc import shard_put
-from tpu_inference.layers.jax.transformer_block import TransformerBlock
-from tpu_inference.logger import init_logger
-from tpu_inference.models.jax.utils.weight_utils import (
-    get_param, model_weights_generator, print_param_info, reshape_params,
-    transpose_params)
-
-logger = init_logger(__name__)
-
-
-class LlamaGuard4ForCausalLM(nnx.Module):
-
-    def __init__(self,
-                 vllm_config: VllmConfig,
-                 rng: PRNGKey,
-                 mesh: Mesh,
-                 force_random_weights: bool = False):
-        logger.warning(
-            "🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨\n"
-            "Llama Guard 4 (JAX) is WIP: Only the text modality is currently implemented. "
-            "Multimodal inputs will fail.\n"
-            "🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨")
-        assert mesh is not None
-
-        self.vllm_config = vllm_config
-        self.vllm_config.model_config.dtype = torch.bfloat16
-        model_config = vllm_config.model_config
-        text_config = model_config.hf_config.text_config
-
-        self.mesh = mesh
-        self.is_verbose = getattr(self.vllm_config.additional_config,
-                                  "is_verbose", False)
-
-        self.use_qk_norm = getattr(text_config, "use_qk_norm", True)
-
-        vocab_size = model_config.get_vocab_size()
-        self.hidden_size = model_config.get_hidden_size()
-
-        self.dtype: jnp.dtype = jnp.bfloat16
-
-        self.num_layers: int = getattr(text_config, "num_layers", 48)
-        hidden_act: str = getattr(text_config, "hidden_act", "silu")
-
-        rms_norm_eps = getattr(text_config, "rms_norm_eps", 1e-5)
-        self.num_attention_heads = getattr(text_config, "num_attention_heads",
-                                           40)
-        self.num_key_value_heads = getattr(text_config, "num_key_value_heads",
-                                           8)
-        self.head_dim = getattr(text_config, "head_dim", 128)
-
-        intermediate_size = getattr(text_config, "intermediate_size", 8192)
-
-        self.rope_theta_text = getattr(text_config, "rope_theta", 500000.0)
-        self.rope_scaling = getattr(text_config, "rope_scaling")
-
-        self.rng = nnx.Rngs(rng)
-
-        self.embedder = Embedder(
-            vocab_size=vocab_size,
-            hidden_size=self.hidden_size,
-            dtype=self.dtype,
-            vd_sharding=(('data', 'model'), None),
-            rngs=self.rng,
-            random_init=force_random_weights,
-        )
-
-        self.layers = []
-
-        for i in range(self.num_layers):
-            use_attention_rope = True
-
-            custom_module = DenseFFW(dtype=self.dtype,
-                                     hidden_act=hidden_act,
-                                     hidden_size=self.hidden_size,
-                                     intermediate_size=intermediate_size,
-                                     random_init=force_random_weights,
-                                     rngs=self.rng,
-                                     df_sharding=P(None, 'model'),
-                                     fd_sharding=P('model', None),
-                                     activation_ffw_td=P('data', None))
-
-            attn = Llama4Attention(
-                hidden_size=self.hidden_size,
-                dtype=self.dtype,
-                num_attention_heads=self.num_attention_heads,
-                num_key_value_heads=self.num_key_value_heads,
-                head_dim=self.head_dim,
-                rope_theta=self.rope_theta_text,
-                rope_scaling={
-                    "scale_factor":
-                    self.rope_scaling["factor"],
-                    "low_freq_factor":
-                    self.rope_scaling["low_freq_factor"],
-                    "high_freq_factor":
-                    self.rope_scaling["high_freq_factor"],
-                    "original_max_position_embeddings":
-                    self.rope_scaling["original_max_position_embeddings"]
-                },
-                rngs=self.rng,
-                rope_input_ordering="interleaved",
-                # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
-                kv_cache_dtype=vllm_config.cache_config.cache_dtype,
-                temperature_tuning=True,
-                temperature_tuning_scale=0.1,
-                temperature_tuning_floor_scale=8192,
-                use_qk_norm=self.use_qk_norm,
-                attention_chunk_size=None if use_attention_rope else 8192,
-                mesh=self.mesh,
-                random_init=force_random_weights,
-                activation_attention_td=('data', 'model'),
-                activation_q_td=('data', 'model'),
-                query_tnh=P('data', 'model', None),
-                keyvalue_skh=P('data', 'model', None),
-                activation_attention_out_td=('data', 'model'),
-                attn_o_tnh=P('data', 'model', None),
-                dnh_sharding=(None, 'model', None),
-                dkh_sharding=(None, 'model', None),
-                nhd_sharding=('model', None, None),
-            )
-
-            pre_attention_norm = RMSNorm(
-                dims=self.hidden_size,
-                random_init=force_random_weights,
-                epsilon=rms_norm_eps,
-                rngs=self.rng,
-                activation_ffw_td=('data', None),
-                with_scale=True,
-                dtype=self.dtype,
-            )
-
-            pre_mlp_norm = RMSNorm(
-                dims=self.hidden_size,
-                activation_ffw_td=('data', None),
-                epsilon=rms_norm_eps,
-                rngs=self.rng,
-                with_scale=True,
-                dtype=self.dtype,
-                random_init=force_random_weights,
-            )
-
-            block = TransformerBlock(custom_module=custom_module,
-                                     attn=attn,
-                                     pre_attention_norm=pre_attention_norm,
-                                     pre_mlp_norm=pre_mlp_norm,
-                                     use_attention_rope=use_attention_rope)
-            self.layers.append(block)
-
-        self.final_norm = RMSNorm(
-            dims=self.hidden_size,
-            activation_ffw_td=P(),
-            epsilon=rms_norm_eps,
-            rngs=self.rng,
-            with_scale=True,
-            dtype=self.dtype,
-            random_init=force_random_weights,
-        )
-
-        self.lm_head = LMhead(vocab_size=vocab_size,
-                              hidden_size=self.hidden_size,
-                              dtype=self.dtype,
-                              rngs=self.rng,
-                              vd_sharding=(('data', 'model'), None),
-                              dv_sharding=(None, ('data', 'model')),
-                              random_init=force_random_weights)
-        if self.is_verbose:
-            self._print_model_architecture()
-
-    def _print_model_architecture(self):
-
-        logger.info("### Embedding ###")
-        nnx.display(self.embedder)
-
-        logger.info("\n### Layers ###")
-        for i, layer in enumerate(self.layers):
-            logger.info(f"\n--- Layer {i} ---")
-            nnx.display(layer)
-
-        logger.info("\n### LM Head ###")
-        nnx.display(self.lm_head)
-
-    def load_weights(self, rng: jax.Array, cache_dir: Optional[str] = None):
-        self.rng = nnx.Rngs(rng)
-
-        weight_loader = LlamaGuard4WeightLoader(
-            vllm_config=self.vllm_config,
-            hidden_size=self.hidden_size,
-            attn_heads=self.num_attention_heads,
-            num_key_value_heads=self.num_key_value_heads,
-            attn_head_dim=self.head_dim)
-        weight_loader.load_weights(self)
-
-    def __call__(
-        self,
-        kv_caches: List[jax.Array],
-        input_ids: jax.Array,
-        attention_metadata: AttentionMetadata,
-        inputs_embeds: Optional[jax.Array] = None,
-        layer_metadata_tuple: Optional[Tuple] = None,
-        lora_metadata: Optional[Any] = None,
-        *args,
-    ) -> Tuple[List[KVCacheType], jax.Array]:
-        is_prefill = False
-
-        if inputs_embeds is not None:
-            x_TD = inputs_embeds
-        elif input_ids is not None:
-            x_TD = self.embedder.encode(input_ids)
-        else:
-            raise ValueError(
-                "Cannot run forward pass: Both input_ids and inputs_embeds are None."
-            )
-
-        for (i, block) in enumerate(self.layers):
-            kv_cache = kv_caches[i]
-            new_kv_cache, x_TD = block(x_TD, is_prefill, kv_cache,
-                                       attention_metadata)
-            jax.block_until_ready(x_TD)
-            kv_caches[i] = new_kv_cache
-
-        final_activation_TD = self.final_norm(x_TD)
-
-        return kv_caches, final_activation_TD, []
-
-    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
-        logits_TV = jnp.dot(hidden_states,
-                            self.lm_head.input_embedding_table_DV.value)
-        return logits_TV
-
-    def get_input_embeddings(
-        self,
-        input_ids: jax.Array,
-        multimodal_embeddings: Optional[List[jax.Array]] = None
-    ) -> jax.Array:
-        """
-        Computes the embeddings for text input (used for input to fusion).
-        """
-        return self.embedder.encode(input_ids)
-
-
-class LlamaGuard4WeightLoader:
-
-    def __init__(self, vllm_config: VllmConfig, hidden_size, attn_heads,
-                 num_key_value_heads, attn_head_dim):
-        self.names_and_weights_generator = model_weights_generator(
-            model_name_or_path=vllm_config.model_config.model,
-            framework="flax",
-            filter_regex="language_model",
-            download_dir=vllm_config.load_config.download_dir)
-        self.is_verbose = getattr(vllm_config.additional_config, "is_verbose",
-                                  False)
-        self._transpose_map = {
-            "q_proj": (2, 0, 1),
-            "k_proj": (2, 0, 1),
-            "v_proj": (2, 0, 1),
-            "o_proj": (1, 2, 0),
-            "lm_head": (1, 0),
-            "feed_forward.down_proj": (1, 0),
-            "feed_forward.gate_proj": (1, 0),
-            "feed_forward.up_proj": (1, 0),
-            "mlp.down_proj": (1, 0),
-            "mlp.gate_proj": (1, 0),
-            "mlp.up_proj": (1, 0),
-        }
-        self._weight_shape_map = {
-            "q_proj": (attn_heads, attn_head_dim, hidden_size),
-            "k_proj": (num_key_value_heads, attn_head_dim, hidden_size),
-            "v_proj": (num_key_value_heads, attn_head_dim, hidden_size),
-            "o_proj": (hidden_size, attn_heads, attn_head_dim),
-        }
-
-        self._loaded_to_standardized_keys = {
-            "language_model.model.embed_tokens.weight":
-            "embedder.input_embedding_table_VD",
-            "language_model.lm_head.weight":
-            "lm_head.input_embedding_table_DV",
-            "language_model.model.norm.weight":
-            "final_norm.scale",
-            "language_model.model.layers.*.input_layernorm.weight":
-            "layers.*.pre_attention_norm.scale",
-            "language_model.model.layers.*.post_attention_layernorm.weight":
-            "layers.*.pre_mlp_norm.scale",
-            "language_model.model.layers.*.self_attn.q_proj.weight":
-            "layers.*.attn.kernel_q_proj_DNH",
-            "language_model.model.layers.*.self_attn.k_proj.weight":
-            "layers.*.attn.kernel_k_proj_DKH",
-            "language_model.model.layers.*.self_attn.v_proj.weight":
-            "layers.*.attn.kernel_v_proj_DKH",
-            "language_model.model.layers.*.self_attn.o_proj.weight":
-            "layers.*.attn.kernel_o_proj_NHD",
-            "language_model.model.layers.*.feed_forward.gate_proj.weight":
-            "layers.*.custom_module.kernel_gating_DF",
-            "language_model.model.layers.*.feed_forward.up_proj.weight":
-            "layers.*.custom_module.kernel_up_proj_DF",
-            "language_model.model.layers.*.feed_forward.down_proj.weight":
-            "layers.*.custom_module.kernel_down_proj_FD",
-        }
-
-    def map_loaded_to_standardized_name(self, loaded_key: str) -> str:
-        if "layer" in loaded_key:
-            layer_num = re.search(r"layers\.(\d+)", loaded_key).group(1)
-            layer_key = re.sub(r"layers\.\d+", "layers.*", loaded_key)
-            mapped_key = self._loaded_to_standardized_keys.get(
-                layer_key, loaded_key)
-            mapped_key = re.sub(r"layers\.\*", f"layers.{layer_num}",
-                                mapped_key)
-        else:
-            mapped_key = self._loaded_to_standardized_keys.get(
-                loaded_key, loaded_key)
-        return mapped_key
-
-    def load_weights(self, model_for_loading: nnx.Module):
-        model_params = nnx.state(model_for_loading)
-        with jax.default_device(jax.devices("cpu")[0]):
-            for loaded_name, loaded_weight in self.names_and_weights_generator:
-                if loaded_name.endswith(".bias"):
-                    continue
-                if "vision_model" in loaded_name or "multi_modal_projector" in loaded_name:
-                    continue
-
-                mapped_name = self.map_loaded_to_standardized_name(loaded_name)
-                model_weight = get_param(model_params, mapped_name)
-
-                if not loaded_name.endswith(".bias"):
-                    # For other layers, continue to use the transpose_params helper.
-                    loaded_weight = reshape_params(loaded_name, loaded_weight,
-                                                   self._weight_shape_map)
-                    loaded_weight = transpose_params(loaded_name,
-                                                     loaded_weight,
-                                                     self._transpose_map)
-                if model_weight.value.shape != loaded_weight.shape:
-                    raise ValueError(
-                        f"Loaded shape for {loaded_name}: {loaded_weight.shape} "
-                        f"does not match model shape for {mapped_name}: {model_weight.value.shape}!"
-                    )
-                logger.debug(
-                    f"Transformed parameter {loaded_name} to {mapped_name}: {loaded_weight.shape} --> {model_weight.value.shape}"
-                )
-
-                model_weight.value = shard_put(loaded_weight,
-                                               model_weight.sharding,
-                                               mesh=model_for_loading.mesh)
-                if self.is_verbose:
-                    print_param_info(model_weight, loaded_name)
-
-            nnx.update(model_for_loading, model_params)
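Note: the key move in the removed LlamaGuard4WeightLoader is the wildcard round-trip in map_loaded_to_standardized_name: capture the layer index, substitute "layers.*" so one lookup table covers all layers, then splice the index back into the standardized key. A standalone sketch (KEY_MAP holds one entry from the removed table; map_name is our name for the same logic):

import re

KEY_MAP = {
    "language_model.model.layers.*.self_attn.q_proj.weight":
    "layers.*.attn.kernel_q_proj_DNH",
}

def map_name(loaded_key: str) -> str:
    if "layer" in loaded_key:
        layer_num = re.search(r"layers\.(\d+)", loaded_key).group(1)
        wildcard = re.sub(r"layers\.\d+", "layers.*", loaded_key)
        mapped = KEY_MAP.get(wildcard, loaded_key)
        return re.sub(r"layers\.\*", f"layers.{layer_num}", mapped)
    return KEY_MAP.get(loaded_key, loaded_key)

print(map_name("language_model.model.layers.7.self_attn.q_proj.weight"))
# -> layers.7.attn.kernel_q_proj_DNH

Keying the table on wildcards keeps it at one entry per parameter kind instead of one per layer, which is why the removed file's table stayed at a dozen entries for a 48-layer model.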