tpu-inference 0.11.1.dev202511150811__py3-none-any.whl → 0.11.1.dev202511270815__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/lora/test_layers.py +0 -6
- tests/lora/utils.py +0 -8
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +2 -3
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +1 -1
- tpu_inference/executors/ray_distributed_executor.py +27 -11
- tpu_inference/kernels/fused_moe/v1/kernel.py +641 -110
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +141 -107
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +2 -1
- tpu_inference/layers/vllm/fused_moe.py +74 -25
- tpu_inference/layers/vllm/quantization/common.py +6 -1
- tpu_inference/layers/vllm/quantization/mxfp4.py +135 -61
- tpu_inference/layers/vllm/quantization/unquantized.py +107 -113
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +43 -11
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +8 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +163 -48
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/weight_utils.py +198 -143
- tpu_inference/models/vllm/vllm_model_wrapper.py +13 -5
- tpu_inference/platforms/tpu_platform.py +15 -2
- tpu_inference/runner/compilation_manager.py +58 -33
- tpu_inference/runner/kv_cache_manager.py +9 -3
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +203 -102
- tpu_inference/spec_decode/jax/eagle3.py +19 -2
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +5 -4
- tpu_inference/worker/tpu_worker.py +160 -23
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/METADATA +3 -2
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/RECORD +43 -48
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511150811.dist-info → tpu_inference-0.11.1.dev202511270815.dist-info}/top_level.txt +0 -0
tpu_inference/models/jax/llama_guard_4.py
@@ -0,0 +1,361 @@
+import re
+from typing import Any, List, Optional, Tuple
+
+import jax
+import jax.numpy as jnp
+import torch
+from flax import nnx
+from flax.typing import PRNGKey
+from jax.sharding import Mesh
+from jax.sharding import PartitionSpec as P
+from vllm.config import VllmConfig
+
+from tpu_inference.layers.jax.attention.attention import AttentionMetadata
+from tpu_inference.layers.jax.attention.llama4_attention import Llama4Attention
+from tpu_inference.layers.jax.constants import KVCacheType
+from tpu_inference.layers.jax.layers import DenseFFW, Embedder, LMhead, RMSNorm
+from tpu_inference.layers.jax.misc import shard_put
+from tpu_inference.layers.jax.transformer_block import TransformerBlock
+from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.utils.weight_utils import (
+    get_param, model_weights_generator, print_param_info, reshape_params,
+    transpose_params)
+
+logger = init_logger(__name__)
+
+
+class LlamaGuard4ForCausalLM(nnx.Module):
+
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 rng: PRNGKey,
+                 mesh: Mesh,
+                 force_random_weights: bool = False):
+        logger.warning(
+            "🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨\n"
+            "Llama Guard 4 (JAX) is WIP: Only the text modality is currently implemented. "
+            "Multimodal inputs will fail.\n"
+            "🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨")
+        assert mesh is not None
+
+        self.vllm_config = vllm_config
+        self.vllm_config.model_config.dtype = torch.bfloat16
+        model_config = vllm_config.model_config
+        text_config = model_config.hf_config.text_config
+
+        self.mesh = mesh
+        self.is_verbose = getattr(self.vllm_config.additional_config,
+                                  "is_verbose", False)
+
+        self.use_qk_norm = getattr(text_config, "use_qk_norm", True)
+
+        vocab_size = model_config.get_vocab_size()
+        self.hidden_size = model_config.get_hidden_size()
+
+        self.dtype: jnp.dtype = jnp.bfloat16
+
+        self.num_layers: int = getattr(text_config, "num_layers", 48)
+        hidden_act: str = getattr(text_config, "hidden_act", "silu")
+
+        rms_norm_eps = getattr(text_config, "rms_norm_eps", 1e-5)
+        self.num_attention_heads = getattr(text_config, "num_attention_heads",
+                                           40)
+        self.num_key_value_heads = getattr(text_config, "num_key_value_heads",
+                                           8)
+        self.head_dim = getattr(text_config, "head_dim", 128)
+
+        intermediate_size = getattr(text_config, "intermediate_size", 8192)
+
+        self.rope_theta_text = getattr(text_config, "rope_theta", 500000.0)
+        self.rope_scaling = getattr(text_config, "rope_scaling")
+
+        self.rng = nnx.Rngs(rng)
+
+        self.embedder = Embedder(
+            vocab_size=vocab_size,
+            hidden_size=self.hidden_size,
+            dtype=self.dtype,
+            vd_sharding=(('data', 'model'), None),
+            rngs=self.rng,
+            random_init=force_random_weights,
+        )
+
+        self.layers = []
+
+        for i in range(self.num_layers):
+            use_attention_rope = True
+
+            custom_module = DenseFFW(dtype=self.dtype,
+                                     hidden_act=hidden_act,
+                                     hidden_size=self.hidden_size,
+                                     intermediate_size=intermediate_size,
+                                     random_init=force_random_weights,
+                                     rngs=self.rng,
+                                     df_sharding=P(None, 'model'),
+                                     fd_sharding=P('model', None),
+                                     activation_ffw_td=P('data', None))
+
+            attn = Llama4Attention(
+                hidden_size=self.hidden_size,
+                dtype=self.dtype,
+                num_attention_heads=self.num_attention_heads,
+                num_key_value_heads=self.num_key_value_heads,
+                head_dim=self.head_dim,
+                rope_theta=self.rope_theta_text,
+                rope_scaling={
+                    "scale_factor":
+                    self.rope_scaling["factor"],
+                    "low_freq_factor":
+                    self.rope_scaling["low_freq_factor"],
+                    "high_freq_factor":
+                    self.rope_scaling["high_freq_factor"],
+                    "original_max_position_embeddings":
+                    self.rope_scaling["original_max_position_embeddings"]
+                },
+                rngs=self.rng,
+                rope_input_ordering="interleaved",
+                # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
+                kv_cache_dtype=vllm_config.cache_config.cache_dtype,
+                temperature_tuning=True,
+                temperature_tuning_scale=0.1,
+                temperature_tuning_floor_scale=8192,
+                use_qk_norm=self.use_qk_norm,
+                attention_chunk_size=None if use_attention_rope else 8192,
+                mesh=self.mesh,
+                random_init=force_random_weights,
+                activation_attention_td=('data', 'model'),
+                activation_q_td=('data', 'model'),
+                query_tnh=P('data', 'model', None),
+                keyvalue_skh=P('data', 'model', None),
+                activation_attention_out_td=('data', 'model'),
+                attn_o_tnh=P('data', 'model', None),
+                dnh_sharding=(None, 'model', None),
+                dkh_sharding=(None, 'model', None),
+                nhd_sharding=('model', None, None),
+            )
+
+            pre_attention_norm = RMSNorm(
+                dims=self.hidden_size,
+                random_init=force_random_weights,
+                epsilon=rms_norm_eps,
+                rngs=self.rng,
+                activation_ffw_td=('data', None),
+                with_scale=True,
+                dtype=self.dtype,
+            )
+
+            pre_mlp_norm = RMSNorm(
+                dims=self.hidden_size,
+                activation_ffw_td=('data', None),
+                epsilon=rms_norm_eps,
+                rngs=self.rng,
+                with_scale=True,
+                dtype=self.dtype,
+                random_init=force_random_weights,
+            )
+
+            block = TransformerBlock(custom_module=custom_module,
+                                     attn=attn,
+                                     pre_attention_norm=pre_attention_norm,
+                                     pre_mlp_norm=pre_mlp_norm,
+                                     use_attention_rope=use_attention_rope)
+            self.layers.append(block)
+
+        self.final_norm = RMSNorm(
+            dims=self.hidden_size,
+            activation_ffw_td=P(),
+            epsilon=rms_norm_eps,
+            rngs=self.rng,
+            with_scale=True,
+            dtype=self.dtype,
+            random_init=force_random_weights,
+        )
+
+        self.lm_head = LMhead(vocab_size=vocab_size,
+                              hidden_size=self.hidden_size,
+                              dtype=self.dtype,
+                              rngs=self.rng,
+                              vd_sharding=(('data', 'model'), None),
+                              dv_sharding=(None, ('data', 'model')),
+                              random_init=force_random_weights)
+        if self.is_verbose:
+            self._print_model_architecture()
+
+    def _print_model_architecture(self):
+
+        logger.info("### Embedding ###")
+        nnx.display(self.embedder)
+
+        logger.info("\n### Layers ###")
+        for i, layer in enumerate(self.layers):
+            logger.info(f"\n--- Layer {i} ---")
+            nnx.display(layer)
+
+        logger.info("\n### LM Head ###")
+        nnx.display(self.lm_head)
+
+    def load_weights(self, rng: jax.Array, cache_dir: Optional[str] = None):
+        self.rng = nnx.Rngs(rng)
+
+        weight_loader = LlamaGuard4WeightLoader(
+            vllm_config=self.vllm_config,
+            hidden_size=self.hidden_size,
+            attn_heads=self.num_attention_heads,
+            num_key_value_heads=self.num_key_value_heads,
+            attn_head_dim=self.head_dim)
+        weight_loader.load_weights(self)
+
+    def __call__(
+        self,
+        kv_caches: List[jax.Array],
+        input_ids: jax.Array,
+        attention_metadata: AttentionMetadata,
+        inputs_embeds: Optional[jax.Array] = None,
+        layer_metadata_tuple: Optional[Tuple] = None,
+        lora_metadata: Optional[Any] = None,
+        *args,
+    ) -> Tuple[List[KVCacheType], jax.Array]:
+        is_prefill = False
+
+        if inputs_embeds is not None:
+            x_TD = inputs_embeds
+        elif input_ids is not None:
+            x_TD = self.embedder.encode(input_ids)
+        else:
+            raise ValueError(
+                "Cannot run forward pass: Both input_ids and inputs_embeds are None."
+            )
+
+        for (i, block) in enumerate(self.layers):
+            kv_cache = kv_caches[i]
+            new_kv_cache, x_TD = block(x_TD, is_prefill, kv_cache,
+                                       attention_metadata)
+            jax.block_until_ready(x_TD)
+            kv_caches[i] = new_kv_cache
+
+        final_activation_TD = self.final_norm(x_TD)
+
+        return kv_caches, final_activation_TD, []
+
+    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+        logits_TV = jnp.dot(hidden_states,
+                            self.lm_head.input_embedding_table_DV.value)
+        return logits_TV
+
+    def get_input_embeddings(
+        self,
+        input_ids: jax.Array,
+        multimodal_embeddings: Optional[List[jax.Array]] = None
+    ) -> jax.Array:
+        """
+        Computes the embeddings for text input (used for input to fusion).
+        """
+        return self.embedder.encode(input_ids)
+
+
+class LlamaGuard4WeightLoader:
+
+    def __init__(self, vllm_config: VllmConfig, hidden_size, attn_heads,
+                 num_key_value_heads, attn_head_dim):
+        self.names_and_weights_generator = model_weights_generator(
+            model_name_or_path=vllm_config.model_config.model,
+            framework="flax",
+            filter_regex="language_model",
+            download_dir=vllm_config.load_config.download_dir)
+        self.is_verbose = getattr(vllm_config.additional_config, "is_verbose",
+                                  False)
+        self._transpose_map = {
+            "q_proj": (2, 0, 1),
+            "k_proj": (2, 0, 1),
+            "v_proj": (2, 0, 1),
+            "o_proj": (1, 2, 0),
+            "lm_head": (1, 0),
+            "feed_forward.down_proj": (1, 0),
+            "feed_forward.gate_proj": (1, 0),
+            "feed_forward.up_proj": (1, 0),
+            "mlp.down_proj": (1, 0),
+            "mlp.gate_proj": (1, 0),
+            "mlp.up_proj": (1, 0),
+        }
+        self._weight_shape_map = {
+            "q_proj": (attn_heads, attn_head_dim, hidden_size),
+            "k_proj": (num_key_value_heads, attn_head_dim, hidden_size),
+            "v_proj": (num_key_value_heads, attn_head_dim, hidden_size),
+            "o_proj": (hidden_size, attn_heads, attn_head_dim),
+        }
+
+        self._loaded_to_standardized_keys = {
+            "language_model.model.embed_tokens.weight":
+            "embedder.input_embedding_table_VD",
+            "language_model.lm_head.weight":
+            "lm_head.input_embedding_table_DV",
+            "language_model.model.norm.weight":
+            "final_norm.scale",
+            "language_model.model.layers.*.input_layernorm.weight":
+            "layers.*.pre_attention_norm.scale",
+            "language_model.model.layers.*.post_attention_layernorm.weight":
+            "layers.*.pre_mlp_norm.scale",
+            "language_model.model.layers.*.self_attn.q_proj.weight":
+            "layers.*.attn.kernel_q_proj_DNH",
+            "language_model.model.layers.*.self_attn.k_proj.weight":
+            "layers.*.attn.kernel_k_proj_DKH",
+            "language_model.model.layers.*.self_attn.v_proj.weight":
+            "layers.*.attn.kernel_v_proj_DKH",
+            "language_model.model.layers.*.self_attn.o_proj.weight":
+            "layers.*.attn.kernel_o_proj_NHD",
+            "language_model.model.layers.*.feed_forward.gate_proj.weight":
+            "layers.*.custom_module.kernel_gating_DF",
+            "language_model.model.layers.*.feed_forward.up_proj.weight":
+            "layers.*.custom_module.kernel_up_proj_DF",
+            "language_model.model.layers.*.feed_forward.down_proj.weight":
+            "layers.*.custom_module.kernel_down_proj_FD",
+        }
+
+    def map_loaded_to_standardized_name(self, loaded_key: str) -> str:
+        if "layer" in loaded_key:
+            layer_num = re.search(r"layers\.(\d+)", loaded_key).group(1)
+            layer_key = re.sub(r"layers\.\d+", "layers.*", loaded_key)
+            mapped_key = self._loaded_to_standardized_keys.get(
+                layer_key, loaded_key)
+            mapped_key = re.sub(r"layers\.\*", f"layers.{layer_num}",
+                                mapped_key)
+        else:
+            mapped_key = self._loaded_to_standardized_keys.get(
+                loaded_key, loaded_key)
+        return mapped_key
+
+    def load_weights(self, model_for_loading: nnx.Module):
+        model_params = nnx.state(model_for_loading)
+        with jax.default_device(jax.devices("cpu")[0]):
+            for loaded_name, loaded_weight in self.names_and_weights_generator:
+                if loaded_name.endswith(".bias"):
+                    continue
+                if "vision_model" in loaded_name or "multi_modal_projector" in loaded_name:
+                    continue
+
+                mapped_name = self.map_loaded_to_standardized_name(loaded_name)
+                model_weight = get_param(model_params, mapped_name)
+
+                if not loaded_name.endswith(".bias"):
+                    # For other layers, continue to use the transpose_params helper.
+                    loaded_weight = reshape_params(loaded_name, loaded_weight,
+                                                   self._weight_shape_map)
+                    loaded_weight = transpose_params(loaded_name,
+                                                     loaded_weight,
+                                                     self._transpose_map)
+                if model_weight.value.shape != loaded_weight.shape:
+                    raise ValueError(
+                        f"Loaded shape for {loaded_name}: {loaded_weight.shape} "
+                        f"does not match model shape for {mapped_name}: {model_weight.value.shape}!"
+                    )
+                logger.debug(
+                    f"Transformed parameter {loaded_name} to {mapped_name}: {loaded_weight.shape} --> {model_weight.value.shape}"
+                )
+
+                model_weight.value = shard_put(loaded_weight,
+                                               model_weight.sharding,
+                                               mesh=model_for_loading.mesh)
+                if self.is_verbose:
+                    print_param_info(model_weight, loaded_name)
+
+        nnx.update(model_for_loading, model_params)
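Note on the weight-name mapping above: `_loaded_to_standardized_keys` uses a `layers.*` wildcard so a single entry covers every decoder layer, and `map_loaded_to_standardized_name` strips the concrete layer index before the lookup and reinserts it afterwards. A minimal standalone sketch of that resolution step, using a hypothetical two-entry subset of the table rather than the full mapping:

import re

# Hypothetical subset of the checkpoint-key -> model-key table shown above.
KEY_MAP = {
    "language_model.model.layers.*.self_attn.q_proj.weight":
        "layers.*.attn.kernel_q_proj_DNH",
    "language_model.model.norm.weight": "final_norm.scale",
}

def map_key(loaded_key: str) -> str:
    # Layer-indexed keys: swap the index for '*', look up, then restore it.
    if "layers." in loaded_key:
        layer_num = re.search(r"layers\.(\d+)", loaded_key).group(1)
        wildcard_key = re.sub(r"layers\.\d+", "layers.*", loaded_key)
        mapped = KEY_MAP.get(wildcard_key, loaded_key)
        return re.sub(r"layers\.\*", f"layers.{layer_num}", mapped)
    return KEY_MAP.get(loaded_key, loaded_key)

# Layer 17's query projection resolves to the matching per-layer model path.
assert (map_key("language_model.model.layers.17.self_attn.q_proj.weight")
        == "layers.17.attn.kernel_q_proj_DNH")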
tpu_inference/models/jax/qwen2.py
@@ -368,7 +368,8 @@ class Qwen2ForCausalLM(nnx.Module):
             "lm_head": "model.lm_head",
         })

-        metadata_map = get_default_maps(self.vllm_config
+        metadata_map = get_default_maps(self.vllm_config.model_config,
+                                        self.mesh, mappings)
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
tpu_inference/models/jax/qwen2_5_vl.py
@@ -486,6 +486,11 @@ class Qwen2_5_VisionTransformer(nnx.Module):
             dtype=dtype,
             rngs=rngs)

+        additional_config = getattr(vllm_config, "additional_config",
+                                    None) or {}
+        self.enable_dynamic_image_sizes = additional_config.get(
+            "enable_dynamic_image_sizes", False)
+
     def rotary_pos_emb_thw(self, t, h, w):
         hpos_ids, wpos_ids = jnp.indices((h, w))
         hpos_ids = hpos_ids.reshape(
@@ -579,21 +584,7 @@ class Qwen2_5_VisionTransformer(nnx.Module):
         seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         return max_seqlen, seqlens

-    def
-                 int]]) -> jax.Array:
-        # x: pixel_values: jax.Array
-        # """Shape:
-        # `(num_patches, num_channels * patch_size * patch_size)`
-        # """
-
-        # grid_thw: image_grid_thw: jax.Array
-        # """Shape: `(num_images, 3)`
-        # This should be in `(grid_t, grid_h, grid_w)` format.
-        # """
-        hidden_states = self.patch_embed(x)
-
-        # num of patches
-        seq_len = x.shape[0]
+    def compute_aux_arrays(self, grid_thw: tuple[tuple[int, int, int]]):
         # num of images/videoes
         num_grids = len(grid_thw)

@@ -638,6 +629,42 @@ class Qwen2_5_VisionTransformer(nnx.Module):
         cu_seqlens = jnp.pad(cu_seqlens, ((1, 0), ),
                              mode='constant',
                              constant_values=0)
+        return window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens
+
+    def pad_inputs(self, x, window_index, rotary_pos_emb, cu_seqlens,
+                   cu_window_seqlens):
+        # padding
+        num_patches = int(rotary_pos_emb.shape[0])
+        bucket_num_patches = 1 << (num_patches - 1).bit_length()
+        num_tokens = window_index.shape[0]
+        bucket_num_tokens = bucket_num_patches // self.spatial_merge_unit
+        vit_merger_window_size = (self.window_size //
+                                  self.spatial_merge_size // self.patch_size)
+        max_windows = (bucket_num_tokens // vit_merger_window_size) + 2
+
+        rotary_pos_emb = jnp.pad(rotary_pos_emb,
+                                 ((0, bucket_num_patches - num_patches),
+                                  (0, 0)))
+        window_index = jnp.concatenate([
+            window_index,
+            jnp.arange(num_tokens, bucket_num_tokens, dtype=jnp.int32)
+        ])
+        cu_window_seqlens = jnp.append(cu_window_seqlens, bucket_num_patches)
+        pad_w = max(0, max_windows + 1 - cu_window_seqlens.shape[0])
+        cu_window_seqlens = jnp.pad(cu_window_seqlens, (0, pad_w), mode='edge')
+        cu_seqlens = jnp.append(cu_seqlens, bucket_num_patches)
+
+        x_padded = jnp.pad(x, ((0, bucket_num_patches - x.shape[0]), (0, 0)))
+
+        return x_padded, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens, num_tokens
+
+    def compute_hidden_states(self, x: jax.Array, window_index: jax.Array,
+                              rotary_pos_emb: jax.Array, cu_seqlens: jax.Array,
+                              cu_window_seqlens: jax.Array) -> jax.Array:
+        hidden_states = self.patch_embed(x)
+
+        # num of patches
+        seq_len = x.shape[0]

         hidden_states = hidden_states.reshape(
             seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
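The `pad_inputs` helper added above rounds the patch count up to the next power of two (`1 << (num_patches - 1).bit_length()`), so arbitrary image sizes map onto a small set of padded bucket shapes and XLA compiles at most once per bucket. A tiny sketch of that rounding, with illustrative values:

def next_pow2(n: int) -> int:
    # Same bucketing trick as pad_inputs: round up to the nearest power of two.
    return 1 << (n - 1).bit_length()

assert next_pow2(1000) == 1024  # 1000 patches fall into the 1024-patch bucket
assert next_pow2(1024) == 1024  # exact powers of two are left unchanged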
@@ -664,6 +691,48 @@ class Qwen2_5_VisionTransformer(nnx.Module):
         hidden_states = hidden_states[reverse_indices, :]
         return hidden_states

+    @jax.jit
+    def encode_padded_jit(self, x_padded, window_index, rotary_pos_emb,
+                          cu_seqlens, cu_window_seqlens):
+        return self.compute_hidden_states(x_padded, window_index,
+                                          rotary_pos_emb, cu_seqlens,
+                                          cu_window_seqlens)
+
+    @partial(
+        jax.jit,
+        static_argnames=("grid_thw", ),
+    )
+    def encode_jit(self, x, grid_thw):
+        window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens = self.compute_aux_arrays(
+            grid_thw)
+        return self.compute_hidden_states(x, window_index, rotary_pos_emb,
+                                          cu_seqlens, cu_window_seqlens)
+
+    def __call__(self, x: jax.Array, grid_thw: tuple[tuple[int, int,
+                                                           int]]) -> jax.Array:
+        # x: pixel_values: jax.Array
+        # """Shape:
+        # `(num_patches, num_channels * patch_size * patch_size)`
+        # """
+
+        # grid_thw: image_grid_thw: jax.Array
+        # """Shape: `(num_images, 3)`
+        # This should be in `(grid_t, grid_h, grid_w)` format.
+        # """
+        if self.enable_dynamic_image_sizes:
+            window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens = self.compute_aux_arrays(
+                grid_thw)
+            x_padded, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens, num_tokens = self.pad_inputs(
+                x, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens)
+
+            hidden_states = self.encode_padded_jit(x_padded, window_index,
+                                                   rotary_pos_emb, cu_seqlens,
+                                                   cu_window_seqlens)
+            return hidden_states[:num_tokens]
+
+        else:
+            return self.encode_jit(x, grid_thw)
+


 class Qwen2_5_VLForConditionalGeneration(nnx.Module):

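For context on the two jitted entry points above: `encode_jit` marks `grid_thw` as a static argument, so every distinct grid tuple triggers its own compilation, while `encode_padded_jit` takes only array arguments whose shapes are already bucketed by `pad_inputs`. A small illustration of the static-argument behaviour in plain JAX (the function and shapes here are made up, not from the source):

from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=("grid_thw", ))
def toy_encode(x, grid_thw):
    # grid_thw must be hashable (a tuple of tuples); each distinct value
    # compiles its own specialization, as with encode_jit above.
    t, h, w = grid_thw[0]
    return x.reshape(t * h * w, -1).sum()

x = jnp.ones((64, 8))
print(toy_encode(x, ((1, 8, 8), )))  # compiles for this grid, cached afterwards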
@@ -888,10 +957,6 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
         # "video"] = self._parse_and_validate_video_input(**kwargs)
         return mm_input_by_modality

-    @partial(
-        jax.jit,
-        static_argnames=("image_grid_thw", ),
-    )
     def get_single_image_embedding(self, image_pixel_values, image_grid_thw):
         return self.visual(image_pixel_values, (image_grid_thw, ))

@@ -1061,7 +1126,8 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
             "lm_head": "language_model.model.lm_head",
         })

-        metadata_map = get_default_maps(self.vllm_config
+        metadata_map = get_default_maps(self.vllm_config.model_config,
+                                        self.mesh, mappings)
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
@@ -1071,33 +1137,82 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
         self,
         run_compilation_fn: Callable,
     ) -> None:
-        image_shapes = []
-        if (warmup_config := self.vllm_config.additional_config.get(
-                "vision_warmup_config")):
-            image_shapes = warmup_config.get("image_shapes")
-
         vc = self.vllm_config.model_config.hf_config.vision_config
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        patch_input_dim = vc.in_channels * vc.temporal_patch_size * vc.patch_size * vc.patch_size
+        if self.visual.enable_dynamic_image_sizes:
+            spatial_merge_unit = vc.spatial_merge_size**2
+            max_num_batched_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
+            mm_kwargs = self.vllm_config.model_config.multimodal_config.mm_processor_kwargs or {}
+            limit_pixels = float(mm_kwargs.get("max_pixels", float('inf')))
+
+            max_patches = int(
+                min(max_num_batched_tokens * spatial_merge_unit,
+                    limit_pixels / (vc.patch_size**2)))
+
+            num_patches_paddings = [
+                1 << i for i in range(4, (max_patches - 1).bit_length() + 1)
+            ]
+            rotary_dim = vc.hidden_size // vc.num_heads // 2
+            vit_merger_window_size = (vc.window_size //
+                                      vc.spatial_merge_size // vc.patch_size)
+
+            for num_patches in num_patches_paddings:
+                dummy_x_padded = jnp.ones(
+                    (num_patches, patch_input_dim),
+                    dtype=self.vllm_config.model_config.dtype)
+
+                num_tokens = num_patches // spatial_merge_unit
+                dummy_window_index = jnp.arange(num_tokens, dtype=jnp.int32)
+
+                dummy_rotary_pos_emb = jnp.ones(
+                    (num_patches, rotary_dim),
+                    dtype=self.vllm_config.model_config.dtype)
+
+                dummy_cu_seqlens = jnp.array([0, num_patches, num_patches],
+                                             dtype=jnp.int32)
+
+                max_windows = (num_tokens // vit_merger_window_size) + 2
+                patches_per_window = (vit_merger_window_size**
+                                      2) * spatial_merge_unit
+                dummy_cu_window_seqlens = jnp.arange(
+                    max_windows + 1, dtype=jnp.int32) * patches_per_window
+                dummy_cu_window_seqlens = jnp.minimum(dummy_cu_window_seqlens,
+                                                      num_patches)
+
+                run_compilation_fn("vision_encoder_padded",
+                                   self.visual.encode_padded_jit,
+                                   dummy_x_padded,
+                                   dummy_window_index,
+                                   dummy_rotary_pos_emb,
+                                   dummy_cu_seqlens,
+                                   dummy_cu_window_seqlens,
+                                   num_patches=num_patches)
+        else:
+            image_shapes = []
+            if (warmup_config := self.vllm_config.additional_config.get(
+                    "vision_warmup_config")):
+                image_shapes = warmup_config.get("image_shapes")
+
+            factor = vc.patch_size * vc.spatial_merge_size
+            for input_hw in image_shapes:
+                if not isinstance(input_hw, list) or len(input_hw) != 2:
+                    logger.warning(f"Skipping invalid shape {input_hw}.")
+                    continue
+                h_input, w_input = input_hw
+                h_processed = round(h_input / factor) * factor
+                w_processed = round(w_input / factor) * factor
+                t, h, w = 1, h_processed // vc.patch_size, w_processed // vc.patch_size
+                grid_thw = (t, h, w)
+                num_patches = t * h * w
+
+                dummy_pixel_values = jnp.ones(
+                    (num_patches, patch_input_dim),
+                    self.vllm_config.model_config.dtype,
+                )
+                dummy_grid_thw = (grid_thw, )

-
-
-
-
-
+                run_compilation_fn("vision_encoder",
+                                   self.visual.encode_jit,
+                                   dummy_pixel_values,
+                                   dummy_grid_thw,
+                                   image_shape=input_hw)
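In the dynamic-image-sizes branch of the warmup above, `num_patches_paddings` enumerates every power-of-two bucket from 16 up to the bucket covering `max_patches`, so each shape `pad_inputs` can emit at runtime is pre-compiled. A short sketch of how that list comes out for an assumed `max_patches` (6000 is an illustrative value, not taken from the source):

max_patches = 6000  # assumed example; derived in the source from
                    # max_num_batched_tokens and the max_pixels limit

num_patches_paddings = [
    1 << i for i in range(4, (max_patches - 1).bit_length() + 1)
]
print(num_patches_paddings)
# [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]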
tpu_inference/models/jax/qwen3.py
@@ -295,7 +295,8 @@ class Qwen3ForCausalLM(nnx.Module):
             "lm_head": "model.lm_head",
         })

-        metadata_map = get_default_maps(self.vllm_config
+        metadata_map = get_default_maps(self.vllm_config.model_config,
+                                        self.mesh, mappings)
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,