tpu-inference 0.11.1.dev202511130813__py3-none-any.whl → 0.11.1.dev202511220812__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of tpu-inference might be problematic.
- tests/lora/test_layers.py +0 -6
- tests/lora/utils.py +0 -8
- tests/test_envs.py +182 -0
- tests/test_utils.py +23 -14
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/core_tpu.py +17 -9
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +2 -3
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +1 -1
- tpu_inference/executors/ray_distributed_executor.py +27 -11
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +110 -64
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +7 -0
- tpu_inference/layers/{jax → common}/attention_interface.py +1 -1
- tpu_inference/layers/common/quant_methods.py +8 -0
- tpu_inference/layers/jax/attention/attention.py +1 -1
- tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
- tpu_inference/layers/jax/sample/sampling.py +2 -2
- tpu_inference/layers/vllm/attention.py +1 -1
- tpu_inference/layers/vllm/quantization/__init__.py +7 -3
- tpu_inference/layers/vllm/quantization/awq.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -2
- tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
- tpu_inference/layers/vllm/quantization/unquantized.py +4 -3
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +12 -11
- tpu_inference/models/jax/llama3.py +4 -3
- tpu_inference/models/jax/llama_eagle3.py +9 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +3 -2
- tpu_inference/models/jax/qwen2_5_vl.py +4 -3
- tpu_inference/models/jax/qwen3.py +3 -2
- tpu_inference/models/jax/utils/weight_utils.py +21 -8
- tpu_inference/models/vllm/vllm_model_wrapper.py +22 -10
- tpu_inference/platforms/tpu_platform.py +17 -7
- tpu_inference/runner/compilation_manager.py +37 -17
- tpu_inference/runner/kv_cache.py +1 -1
- tpu_inference/runner/kv_cache_manager.py +8 -2
- tpu_inference/runner/tpu_runner.py +199 -87
- tpu_inference/spec_decode/jax/eagle3.py +2 -1
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +7 -6
- tpu_inference/worker/tpu_worker.py +159 -23
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/METADATA +2 -2
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/RECORD +52 -54
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- /tpu_inference/layers/{jax → common}/binary_search.py +0 -0
- /tpu_inference/layers/{jax → common}/sharding.py +0 -0
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/top_level.txt +0 -0
tpu_inference/models/jax/llama_guard_4.py (new file)
@@ -0,0 +1,361 @@
+import re
+from typing import Any, List, Optional, Tuple
+
+import jax
+import jax.numpy as jnp
+import torch
+from flax import nnx
+from flax.typing import PRNGKey
+from jax.sharding import Mesh
+from jax.sharding import PartitionSpec as P
+from vllm.config import VllmConfig
+
+from tpu_inference.layers.jax.attention.attention import AttentionMetadata
+from tpu_inference.layers.jax.attention.llama4_attention import Llama4Attention
+from tpu_inference.layers.jax.constants import KVCacheType
+from tpu_inference.layers.jax.layers import DenseFFW, Embedder, LMhead, RMSNorm
+from tpu_inference.layers.jax.misc import shard_put
+from tpu_inference.layers.jax.transformer_block import TransformerBlock
+from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.utils.weight_utils import (
+    get_param, model_weights_generator, print_param_info, reshape_params,
+    transpose_params)
+
+logger = init_logger(__name__)
+
+
+class LlamaGuard4ForCausalLM(nnx.Module):
+
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 rng: PRNGKey,
+                 mesh: Mesh,
+                 force_random_weights: bool = False):
+        logger.warning(
+            "🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨\n"
+            "Llama Guard 4 (JAX) is WIP: Only the text modality is currently implemented. "
+            "Multimodal inputs will fail.\n"
+            "🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨")
+        assert mesh is not None
+
+        self.vllm_config = vllm_config
+        self.vllm_config.model_config.dtype = torch.bfloat16
+        model_config = vllm_config.model_config
+        text_config = model_config.hf_config.text_config
+
+        self.mesh = mesh
+        self.is_verbose = getattr(self.vllm_config.additional_config,
+                                  "is_verbose", False)
+
+        self.use_qk_norm = getattr(text_config, "use_qk_norm", True)
+
+        vocab_size = model_config.get_vocab_size()
+        self.hidden_size = model_config.get_hidden_size()
+
+        self.dtype: jnp.dtype = jnp.bfloat16
+
+        self.num_layers: int = getattr(text_config, "num_layers", 48)
+        hidden_act: str = getattr(text_config, "hidden_act", "silu")
+
+        rms_norm_eps = getattr(text_config, "rms_norm_eps", 1e-5)
+        self.num_attention_heads = getattr(text_config, "num_attention_heads",
+                                           40)
+        self.num_key_value_heads = getattr(text_config, "num_key_value_heads",
+                                           8)
+        self.head_dim = getattr(text_config, "head_dim", 128)
+
+        intermediate_size = getattr(text_config, "intermediate_size", 8192)
+
+        self.rope_theta_text = getattr(text_config, "rope_theta", 500000.0)
+        self.rope_scaling = getattr(text_config, "rope_scaling")
+
+        self.rng = nnx.Rngs(rng)
+
+        self.embedder = Embedder(
+            vocab_size=vocab_size,
+            hidden_size=self.hidden_size,
+            dtype=self.dtype,
+            vd_sharding=(('data', 'model'), None),
+            rngs=self.rng,
+            random_init=force_random_weights,
+        )
+
+        self.layers = []
+
+        for i in range(self.num_layers):
+            use_attention_rope = True
+
+            custom_module = DenseFFW(dtype=self.dtype,
+                                     hidden_act=hidden_act,
+                                     hidden_size=self.hidden_size,
+                                     intermediate_size=intermediate_size,
+                                     random_init=force_random_weights,
+                                     rngs=self.rng,
+                                     df_sharding=P(None, 'model'),
+                                     fd_sharding=P('model', None),
+                                     activation_ffw_td=P('data', None))
+
+            attn = Llama4Attention(
+                hidden_size=self.hidden_size,
+                dtype=self.dtype,
+                num_attention_heads=self.num_attention_heads,
+                num_key_value_heads=self.num_key_value_heads,
+                head_dim=self.head_dim,
+                rope_theta=self.rope_theta_text,
+                rope_scaling={
+                    "scale_factor":
+                    self.rope_scaling["factor"],
+                    "low_freq_factor":
+                    self.rope_scaling["low_freq_factor"],
+                    "high_freq_factor":
+                    self.rope_scaling["high_freq_factor"],
+                    "original_max_position_embeddings":
+                    self.rope_scaling["original_max_position_embeddings"]
+                },
+                rngs=self.rng,
+                rope_input_ordering="interleaved",
+                # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
+                kv_cache_dtype=vllm_config.cache_config.cache_dtype,
+                temperature_tuning=True,
+                temperature_tuning_scale=0.1,
+                temperature_tuning_floor_scale=8192,
+                use_qk_norm=self.use_qk_norm,
+                attention_chunk_size=None if use_attention_rope else 8192,
+                mesh=self.mesh,
+                random_init=force_random_weights,
+                activation_attention_td=('data', 'model'),
+                activation_q_td=('data', 'model'),
+                query_tnh=P('data', 'model', None),
+                keyvalue_skh=P('data', 'model', None),
+                activation_attention_out_td=('data', 'model'),
+                attn_o_tnh=P('data', 'model', None),
+                dnh_sharding=(None, 'model', None),
+                dkh_sharding=(None, 'model', None),
+                nhd_sharding=('model', None, None),
+            )
+
+            pre_attention_norm = RMSNorm(
+                dims=self.hidden_size,
+                random_init=force_random_weights,
+                epsilon=rms_norm_eps,
+                rngs=self.rng,
+                activation_ffw_td=('data', None),
+                with_scale=True,
+                dtype=self.dtype,
+            )
+
+            pre_mlp_norm = RMSNorm(
+                dims=self.hidden_size,
+                activation_ffw_td=('data', None),
+                epsilon=rms_norm_eps,
+                rngs=self.rng,
+                with_scale=True,
+                dtype=self.dtype,
+                random_init=force_random_weights,
+            )
+
+            block = TransformerBlock(custom_module=custom_module,
+                                     attn=attn,
+                                     pre_attention_norm=pre_attention_norm,
+                                     pre_mlp_norm=pre_mlp_norm,
+                                     use_attention_rope=use_attention_rope)
+            self.layers.append(block)
+
+        self.final_norm = RMSNorm(
+            dims=self.hidden_size,
+            activation_ffw_td=P(),
+            epsilon=rms_norm_eps,
+            rngs=self.rng,
+            with_scale=True,
+            dtype=self.dtype,
+            random_init=force_random_weights,
+        )
+
+        self.lm_head = LMhead(vocab_size=vocab_size,
+                              hidden_size=self.hidden_size,
+                              dtype=self.dtype,
+                              rngs=self.rng,
+                              vd_sharding=(('data', 'model'), None),
+                              dv_sharding=(None, ('data', 'model')),
+                              random_init=force_random_weights)
+        if self.is_verbose:
+            self._print_model_architecture()
+
+    def _print_model_architecture(self):
+
+        logger.info("### Embedding ###")
+        nnx.display(self.embedder)
+
+        logger.info("\n### Layers ###")
+        for i, layer in enumerate(self.layers):
+            logger.info(f"\n--- Layer {i} ---")
+            nnx.display(layer)
+
+        logger.info("\n### LM Head ###")
+        nnx.display(self.lm_head)
+
+    def load_weights(self, rng: jax.Array, cache_dir: Optional[str] = None):
+        self.rng = nnx.Rngs(rng)
+
+        weight_loader = LlamaGuard4WeightLoader(
+            vllm_config=self.vllm_config,
+            hidden_size=self.hidden_size,
+            attn_heads=self.num_attention_heads,
+            num_key_value_heads=self.num_key_value_heads,
+            attn_head_dim=self.head_dim)
+        weight_loader.load_weights(self)
+
+    def __call__(
+        self,
+        kv_caches: List[jax.Array],
+        input_ids: jax.Array,
+        attention_metadata: AttentionMetadata,
+        inputs_embeds: Optional[jax.Array] = None,
+        layer_metadata_tuple: Optional[Tuple] = None,
+        lora_metadata: Optional[Any] = None,
+        *args,
+    ) -> Tuple[List[KVCacheType], jax.Array]:
+        is_prefill = False
+
+        if inputs_embeds is not None:
+            x_TD = inputs_embeds
+        elif input_ids is not None:
+            x_TD = self.embedder.encode(input_ids)
+        else:
+            raise ValueError(
+                "Cannot run forward pass: Both input_ids and inputs_embeds are None."
+            )
+
+        for (i, block) in enumerate(self.layers):
+            kv_cache = kv_caches[i]
+            new_kv_cache, x_TD = block(x_TD, is_prefill, kv_cache,
+                                       attention_metadata)
+            jax.block_until_ready(x_TD)
+            kv_caches[i] = new_kv_cache
+
+        final_activation_TD = self.final_norm(x_TD)
+
+        return kv_caches, final_activation_TD, []
+
+    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+        logits_TV = jnp.dot(hidden_states,
+                            self.lm_head.input_embedding_table_DV.value)
+        return logits_TV
+
+    def get_input_embeddings(
+        self,
+        input_ids: jax.Array,
+        multimodal_embeddings: Optional[List[jax.Array]] = None
+    ) -> jax.Array:
+        """
+        Computes the embeddings for text input (used for input to fusion).
+        """
+        return self.embedder.encode(input_ids)
+
+
+class LlamaGuard4WeightLoader:
+
+    def __init__(self, vllm_config: VllmConfig, hidden_size, attn_heads,
+                 num_key_value_heads, attn_head_dim):
+        self.names_and_weights_generator = model_weights_generator(
+            model_name_or_path=vllm_config.model_config.model,
+            framework="flax",
+            filter_regex="language_model",
+            download_dir=vllm_config.load_config.download_dir)
+        self.is_verbose = getattr(vllm_config.additional_config, "is_verbose",
+                                  False)
+        self._transpose_map = {
+            "q_proj": (2, 0, 1),
+            "k_proj": (2, 0, 1),
+            "v_proj": (2, 0, 1),
+            "o_proj": (1, 2, 0),
+            "lm_head": (1, 0),
+            "feed_forward.down_proj": (1, 0),
+            "feed_forward.gate_proj": (1, 0),
+            "feed_forward.up_proj": (1, 0),
+            "mlp.down_proj": (1, 0),
+            "mlp.gate_proj": (1, 0),
+            "mlp.up_proj": (1, 0),
+        }
+        self._weight_shape_map = {
+            "q_proj": (attn_heads, attn_head_dim, hidden_size),
+            "k_proj": (num_key_value_heads, attn_head_dim, hidden_size),
+            "v_proj": (num_key_value_heads, attn_head_dim, hidden_size),
+            "o_proj": (hidden_size, attn_heads, attn_head_dim),
+        }
+
+        self._loaded_to_standardized_keys = {
+            "language_model.model.embed_tokens.weight":
+            "embedder.input_embedding_table_VD",
+            "language_model.lm_head.weight":
+            "lm_head.input_embedding_table_DV",
+            "language_model.model.norm.weight":
+            "final_norm.scale",
+            "language_model.model.layers.*.input_layernorm.weight":
+            "layers.*.pre_attention_norm.scale",
+            "language_model.model.layers.*.post_attention_layernorm.weight":
+            "layers.*.pre_mlp_norm.scale",
+            "language_model.model.layers.*.self_attn.q_proj.weight":
+            "layers.*.attn.kernel_q_proj_DNH",
+            "language_model.model.layers.*.self_attn.k_proj.weight":
+            "layers.*.attn.kernel_k_proj_DKH",
+            "language_model.model.layers.*.self_attn.v_proj.weight":
+            "layers.*.attn.kernel_v_proj_DKH",
+            "language_model.model.layers.*.self_attn.o_proj.weight":
+            "layers.*.attn.kernel_o_proj_NHD",
+            "language_model.model.layers.*.feed_forward.gate_proj.weight":
+            "layers.*.custom_module.kernel_gating_DF",
+            "language_model.model.layers.*.feed_forward.up_proj.weight":
+            "layers.*.custom_module.kernel_up_proj_DF",
+            "language_model.model.layers.*.feed_forward.down_proj.weight":
+            "layers.*.custom_module.kernel_down_proj_FD",
+        }
+
+    def map_loaded_to_standardized_name(self, loaded_key: str) -> str:
+        if "layer" in loaded_key:
+            layer_num = re.search(r"layers\.(\d+)", loaded_key).group(1)
+            layer_key = re.sub(r"layers\.\d+", "layers.*", loaded_key)
+            mapped_key = self._loaded_to_standardized_keys.get(
+                layer_key, loaded_key)
+            mapped_key = re.sub(r"layers\.\*", f"layers.{layer_num}",
+                                mapped_key)
+        else:
+            mapped_key = self._loaded_to_standardized_keys.get(
+                loaded_key, loaded_key)
+        return mapped_key
+
+    def load_weights(self, model_for_loading: nnx.Module):
+        model_params = nnx.state(model_for_loading)
+        with jax.default_device(jax.devices("cpu")[0]):
+            for loaded_name, loaded_weight in self.names_and_weights_generator:
+                if loaded_name.endswith(".bias"):
+                    continue
+                if "vision_model" in loaded_name or "multi_modal_projector" in loaded_name:
+                    continue
+
+                mapped_name = self.map_loaded_to_standardized_name(loaded_name)
+                model_weight = get_param(model_params, mapped_name)
+
+                if not loaded_name.endswith(".bias"):
+                    # For other layers, continue to use the transpose_params helper.
+                    loaded_weight = reshape_params(loaded_name, loaded_weight,
+                                                   self._weight_shape_map)
+                    loaded_weight = transpose_params(loaded_name,
+                                                     loaded_weight,
+                                                     self._transpose_map)
+                if model_weight.value.shape != loaded_weight.shape:
+                    raise ValueError(
+                        f"Loaded shape for {loaded_name}: {loaded_weight.shape} "
+                        f"does not match model shape for {mapped_name}: {model_weight.value.shape}!"
+                    )
+                logger.debug(
+                    f"Transformed parameter {loaded_name} to {mapped_name}: {loaded_weight.shape} --> {model_weight.value.shape}"
+                )
+
+                model_weight.value = shard_put(loaded_weight,
+                                               model_weight.sharding,
+                                               mesh=model_for_loading.mesh)
+                if self.is_verbose:
+                    print_param_info(model_weight, loaded_name)
+
+        nnx.update(model_for_loading, model_params)
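To make the wildcard key mapping above concrete, here is a minimal standalone sketch of the same normalize, look up, substitute pattern; map_name and the one-entry table are illustrative stand-ins for map_loaded_to_standardized_name and _loaded_to_standardized_keys.

import re

table = {
    "language_model.model.layers.*.self_attn.q_proj.weight":
    "layers.*.attn.kernel_q_proj_DNH",
}

def map_name(loaded_key: str) -> str:
    m = re.search(r"layers\.(\d+)", loaded_key)
    if m is None:
        return table.get(loaded_key, loaded_key)
    # Normalize the concrete layer index to the "layers.*" template,
    # look it up, then substitute the index back into the mapped name.
    template = re.sub(r"layers\.\d+", "layers.*", loaded_key)
    mapped = table.get(template, loaded_key)
    return re.sub(r"layers\.\*", f"layers.{m.group(1)}", mapped)

print(map_name("language_model.model.layers.7.self_attn.q_proj.weight"))
# -> layers.7.attn.kernel_q_proj_DNH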
tpu_inference/models/jax/qwen2.py
@@ -8,8 +8,8 @@ from transformers import Qwen2Config, modeling_flax_utils
 from vllm.config import VllmConfig
 
 from tpu_inference import utils
+from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.attention_interface import attention
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
@@ -368,7 +368,8 @@ class Qwen2ForCausalLM(nnx.Module):
             "lm_head": "model.lm_head",
         })
 
-        metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
+        metadata_map = get_default_maps(self.vllm_config.model_config,
+                                        self.mesh, mappings)
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
tpu_inference/models/jax/qwen2_5_vl.py
@@ -14,9 +14,9 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
 from vllm.config import VllmConfig
 
 from tpu_inference import utils as utils
-from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.attention_interface import \
+from tpu_inference.layers.common.attention_interface import \
     sharded_flash_attention
+from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
 # from vllm.model_executor.models.interfaces import MultiModalEmbeddings
@@ -1061,7 +1061,8 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
             "lm_head": "language_model.model.lm_head",
         })
 
-        metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
+        metadata_map = get_default_maps(self.vllm_config.model_config,
+                                        self.mesh, mappings)
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
tpu_inference/models/jax/qwen3.py
@@ -8,8 +8,8 @@ from transformers import Qwen3Config
 from vllm.config import VllmConfig
 
 from tpu_inference import utils
+from tpu_inference.layers.common.attention_interface import attention
 from tpu_inference.layers.common.attention_metadata import AttentionMetadata
-from tpu_inference.layers.jax.attention_interface import attention
 from tpu_inference.layers.jax.rope_interface import apply_rope
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.qwen2 import Qwen2DecoderLayer
@@ -295,7 +295,8 @@ class Qwen3ForCausalLM(nnx.Module):
             "lm_head": "model.lm_head",
         })
 
-        metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
+        metadata_map = get_default_maps(self.vllm_config.model_config,
+                                        self.mesh, mappings)
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
tpu_inference/models/jax/utils/weight_utils.py
@@ -18,7 +18,7 @@ from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from safetensors import safe_open
 
-from tpu_inference import utils
+from tpu_inference import envs, utils
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils import file_utils
 
@@ -197,12 +197,11 @@ def shard_put(x: jax.Array, shardings, mesh: jax.sharding.Mesh) -> jax.Array:
     return jax.device_put(x, shardings)
 
 
-def get_default_maps(vllm_config, mesh: Mesh,
+def get_default_maps(model_config, mesh: Mesh,
                      name_map: dict[str, str]) -> MetadataMap:
     """Load weights from one model weights file to the model, run on single thread."""
     sharding_size = mesh.shape["model"]
 
-    model_config = vllm_config.model_config
     hf_config = model_config.hf_config
 
     num_heads = hf_config.num_attention_heads
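For readers unfamiliar with mesh.shape: a JAX Mesh maps axis names to sizes, so the sharding_size above is the device count along the "model" axis. A standalone sketch (the 1xN device grid is an assumption for illustration):

import numpy as np
import jax
from jax.sharding import Mesh

# Arrange whatever devices are available into a 1 x N grid.
devs = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devs, axis_names=("data", "model"))
print(mesh.shape["model"])  # e.g. 1 on a single-device host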
@@ -273,7 +272,8 @@ def _load_hf_weights_on_thread(vllm_config,
                                weights_file: str,
                                filter_regex: str | None = None,
                                keep_original_dtype_keys_regex: list[str]
-                               | None = None):
+                               | None = None,
+                               exclude_regex: list[str] | None = None):
     name_map = metadata_map.name_map
     reshape_keys = metadata_map.reshape_map
     bias_reshape_keys = metadata_map.bias_reshape_map
@@ -298,6 +298,18 @@ def _load_hf_weights_on_thread(vllm_config,
     for hf_key, hf_weight in model_weights_single_file_generator(
             weights_file, framework="flax", filter_regex=filter_regex):
 
+        # Check if the key should be excluded
+        if exclude_regex:
+            should_exclude = False
+            for pattern in exclude_regex:
+                if re.search(pattern, hf_key):
+                    logger.info(
+                        f"Excluding {hf_key} based on pattern {pattern}")
+                    should_exclude = True
+                    break
+            if should_exclude:
+                continue
+
         # Check if the key should retain its original dtype
         keep_original_dtype = False
         if keep_original_dtype_keys_regex:
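A standalone sketch of the new exclude_regex filtering: any checkpoint key matching one of the patterns is skipped before dtype handling. The keys and patterns below are illustrative, not from a real checkpoint.

import re

exclude_regex = [r"vision_model", r"multi_modal_projector"]
hf_keys = [
    "language_model.model.layers.0.self_attn.q_proj.weight",
    "vision_model.patch_embed.proj.weight",
]
# Keep only the keys that match none of the exclusion patterns.
kept = [
    k for k in hf_keys
    if not any(re.search(p, k) for p in exclude_regex)
]
print(kept)  # only the language_model key survives

A plausible call site (hypothetical, not shown in this diff) would pass exclude_regex=[r"vision_model", r"multi_modal_projector"] to load_hf_weights, mirroring the keys LlamaGuard4WeightLoader skips by substring.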
@@ -408,7 +420,8 @@ def load_hf_weights(vllm_config,
                     mesh: Mesh,
                     filter_regex: str | None = None,
                     is_draft_model: bool = False,
-                    keep_original_dtype_keys_regex: list[str] | None = None):
+                    keep_original_dtype_keys_regex: list[str] | None = None,
+                    exclude_regex: list[str] | None = None):
    """Load weights from all model weights files to the model, run in multi threads."""
     if is_draft_model:
         model_path = vllm_config.speculative_config.draft_model_config.model
@@ -421,7 +434,7 @@ def load_hf_weights(vllm_config,
     # NOTE(xiang): Disable multi-threading mode if running on multi-host.
     # Because multi-threading would cause different JAX processes to load
     # different weights at the same time.
-    if
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         max_workers = 1
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
         futures = [
@@ -433,8 +446,8 @@ def load_hf_weights(vllm_config,
                 mesh,
                 weights_file,
                 filter_regex=filter_regex,
-                keep_original_dtype_keys_regex=keep_original_dtype_keys_regex
-            ) for weights_file in weights_files
+                keep_original_dtype_keys_regex=keep_original_dtype_keys_regex,
+                exclude_regex=exclude_regex) for weights_file in weights_files
         ]
         for future in futures:
             future.result()
tpu_inference/models/vllm/vllm_model_wrapper.py
@@ -25,6 +25,8 @@ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
 from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
 from tpu_inference.layers.vllm.sharding import shard_model_to_tpu
 from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.jax_intermediate_tensor import \
+    JaxIntermediateTensors
 from tpu_inference.models.vllm.vllm_model_wrapper_context import (
     get_vllm_model_wrapper_context, set_vllm_model_wrapper_context)
 from tpu_inference.runner.lora_utils import replace_lora_metadata
@@ -89,13 +91,14 @@ class VllmModelWrapper:
            slice_config = self.vllm_config.device_config.slice
            modified_slice_config = True
            self.vllm_config.device_config.slice = None
+        self.vllm_config.compilation_config.static_forward_context.clear()
+
        vllm_config_for_load = copy.deepcopy(self.vllm_config)
        if modified_slice_config:
            self.vllm_config.device_config.slice = slice_config
        assert self.vllm_config.model_config.dtype in TORCH_DTYPE_TO_JAX, "The model_config.dtype must be a PyTorch dtype."
        vllm_config_for_load.device_config.device = "cpu"
        # Clearing the cached compilation config, otherwise vllm model init will fail
-        vllm_config_for_load.compilation_config.static_forward_context.clear()
 
        # When expert parallelism is enabled, vLLM loads weight in sharding
        # aware manner. Since tpu-inference has its own sharding logic, this
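One plausible reading of moving the clear() ahead of the deepcopy (an interpretation, not stated in the diff): clearing the original config before copying guarantees the load-time copy also starts with an empty static_forward_context. A stand-in sketch:

import copy

class Cfg:
    def __init__(self):
        # Stand-in for the cached static_forward_context.
        self.static_forward_context = {"layer0": object()}

cfg = Cfg()
cfg.static_forward_context.clear()   # clear on the original...
cfg_for_load = copy.deepcopy(cfg)    # ...so the copy starts empty too
print(cfg_for_load.static_forward_context)  # {}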
@@ -117,7 +120,7 @@ class VllmModelWrapper:
 
         # Load the vLLM model and wrap it into a new model whose forward
         # function can calculate the hidden_state and logits.
-        with load_context:
+        with load_context, jax.default_device(jax.devices("cpu")[0]):
             vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
             lora_manager = None
             if vllm_config_for_load.lora_config is not None:
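For context on the load_context change: jax.default_device makes newly created arrays land on the given device, so weights materialize in host memory instead of TPU HBM during loading. A tiny standalone demonstration:

import jax
import jax.numpy as jnp

with jax.default_device(jax.devices("cpu")[0]):
    w = jnp.zeros((4, 4))  # committed to the CPU device
print(w.devices())  # {CpuDevice(id=0)}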
@@ -149,7 +152,8 @@ class VllmModelWrapper:
                "xla_tpu_reduce_scatter_collective_matmul_mode":
                "post_spmd_conservative"
            },
-            static_argnames=("layer_name_to_kvcache_index", ),
+            static_argnames=("layer_name_to_kvcache_index", "is_first_rank",
+                             "is_last_rank"),
        )
        def step_fun(
            params_and_buffers,  # This has been wrapped into torchax TorchValue
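The new is_first_rank/is_last_rank flags are declared static because step_fun branches on them in Python. A minimal standalone sketch of jax.jit's static_argnames behavior (not tpu-inference code): each distinct static value compiles its own specialization.

import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnames=("is_last_rank",))
def step(x, is_last_rank: bool = True):
    # Python-level branching on a traced value would fail under jit;
    # marking the flag static lets each variant compile separately.
    if is_last_rank:
        return x * 2
    return x + 1

print(step(jnp.ones(2), is_last_rank=True))
print(step(jnp.ones(2), is_last_rank=False))  # triggers a second trace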
@@ -157,8 +161,12 @@ class VllmModelWrapper:
            input_ids: jax.Array,
            attn_metadata: AttentionMetadata,
            input_embeds: jax.Array,
+            input_positions: jax.Array,
            layer_name_to_kvcache_index: Sequence[Tuple[str, int]],
            lora_metadata,
+            intermediate_tensors: JaxIntermediateTensors = None,
+            is_first_rank: bool = True,
+            is_last_rank: bool = True,
            *args,
        ) -> Tuple[List[jax.Array], jax.Array]:
            layer_name_to_kvcache_index = dict(layer_name_to_kvcache_index)
@@ -173,12 +181,14 @@ class VllmModelWrapper:
            # torch_view in order to call the Torch function.
            original_lora_metadata = replace_lora_metadata(
                self.model, lora_metadata, self.vllm_config.lora_config)
-
+            if not is_first_rank:
+                intermediate_tensors = intermediate_tensors.to_torch()
+
            output_from_torch = torch.func.functional_call(
                self.model,
                torch_view(params_and_buffers),
                kwargs={
                    "input_ids": torch_view(input_ids),
-                    "positions": torch_view(
+                    "positions": torch_view(input_positions),
                    "intermediate_tensors": None,
                    "inputs_embeds": None,
                },
@@ -188,11 +198,13 @@ class VllmModelWrapper:
                self.vllm_config.lora_config)
            vllm_model_wrapper_context = get_vllm_model_wrapper_context()
            new_kv_caches = vllm_model_wrapper_context.kv_caches
-            # Wrap the
-            # code to consume.
-
-
-
+            # Wrap the output(hidden states or intermediate tensor)
+            # from torch land into a JaxValue for the jax code to consume.
+            if not is_last_rank:
+                output = JaxIntermediateTensors.from_torch(output_from_torch)
+            else:
+                output = jax_view(output_from_torch)
+            return new_kv_caches, output, []
 
        return step_fun
 
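A standalone sketch of the rank-dispatch pattern this hunk introduces; FakeIntermediateTensors is a stub standing in for JaxIntermediateTensors, with identity to_torch/from_torch, so only the control flow mirrors the diff.

from dataclasses import dataclass
from typing import Any

@dataclass
class FakeIntermediateTensors:
    # Stand-in for JaxIntermediateTensors; the conversions are identity stubs.
    data: Any

    def to_torch(self):
        return self

    @staticmethod
    def from_torch(out):
        return FakeIntermediateTensors(out)

def step(inputs, intermediate=None, is_first_rank=True, is_last_rank=True):
    if not is_first_rank:
        # Non-first ranks consume the previous stage's intermediates.
        intermediate = intermediate.to_torch()
    model_out = (intermediate.data if intermediate else inputs) + 1
    # Non-last ranks hand intermediates to the next stage; the last rank
    # returns final hidden states.
    return model_out if is_last_rank else FakeIntermediateTensors.from_torch(model_out)

print(step(1))                        # single stage: 2
mid = step(1, is_last_rank=False)     # first of two stages
print(step(0, mid, is_first_rank=False, is_last_rank=True))  # 3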
tpu_inference/platforms/tpu_platform.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import
-from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
+from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
 
 import jax.numpy as jnp
 import vllm.envs as vllm_envs
@@ -12,7 +11,7 @@ from vllm.platforms.interface import Platform, PlatformEnum
 from vllm.sampling_params import SamplingParams, SamplingType
 
 from tpu_inference import envs
-from tpu_inference.layers.jax.sharding import ShardingConfigManager
+from tpu_inference.layers.common.sharding import ShardingConfigManager
 from tpu_inference.logger import init_logger
 
 if TYPE_CHECKING:
@@ -57,7 +56,8 @@ class TpuPlatform(Platform):
    def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
                             dtype: jnp.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool, use_mla: bool,
-                             has_sink: bool, use_sparse: bool) -> str:
+                             has_sink: bool, use_sparse: bool,
+                             attn_type: Any) -> str:
        from vllm.attention.backends.registry import _Backend
        if selected_backend != _Backend.PALLAS:
            logger.info("Cannot use %s backend on TPU.", selected_backend)
@@ -182,10 +182,16 @@ class TpuPlatform(Platform):
            parallel_config.worker_cls = \
                "tpu_inference.worker.tpu_worker.TPUWorker"
 
-        multihost_backend =
+        multihost_backend = envs.TPU_MULTIHOST_BACKEND
        if not multihost_backend:  # Single host
-
-
+            if parallel_config.pipeline_parallel_size == 1:
+                logger.info("Force using UniProcExecutor for JAX on \
+                    single host without pipeline parallelism.")
+                parallel_config.distributed_executor_backend = "uni"
+            else:
+                logger.info("Force using MultiprocExecutor for JAX on \
+                    single host with pipeline parallelism.")
+                parallel_config.distributed_executor_backend = "mp"
        elif multihost_backend == "ray":
            from tpu_inference.executors.ray_distributed_executor import \
                RayDistributedExecutor
@@ -260,3 +266,7 @@ class TpuPlatform(Platform):
        Returns if the current platform needs to sync weight loader.
        """
        return True
+
+    @classmethod
+    def support_hybrid_kv_cache(cls) -> bool:
+        return True