tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.11.1.dev202511220812__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/lora/test_layers.py +0 -6
- tests/lora/utils.py +0 -8
- tpu_inference/__init__.py +22 -3
- tpu_inference/core/disagg_utils.py +6 -8
- tpu_inference/distributed/tpu_connector.py +2 -3
- tpu_inference/distributed/utils.py +3 -2
- tpu_inference/envs.py +1 -1
- tpu_inference/executors/ray_distributed_executor.py +4 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +77 -54
- tpu_inference/layers/vllm/sharding.py +2 -2
- tpu_inference/lora/torch_punica_tpu.py +1 -2
- tpu_inference/models/common/model_loader.py +9 -9
- tpu_inference/models/jax/llama3.py +2 -1
- tpu_inference/models/jax/llama_eagle3.py +9 -5
- tpu_inference/models/jax/llama_guard_4.py +361 -0
- tpu_inference/models/jax/qwen2.py +2 -1
- tpu_inference/models/jax/qwen2_5_vl.py +2 -1
- tpu_inference/models/jax/qwen3.py +2 -1
- tpu_inference/models/jax/utils/weight_utils.py +21 -8
- tpu_inference/models/vllm/vllm_model_wrapper.py +4 -4
- tpu_inference/platforms/tpu_platform.py +5 -2
- tpu_inference/runner/compilation_manager.py +33 -15
- tpu_inference/runner/kv_cache_manager.py +8 -2
- tpu_inference/runner/tpu_runner.py +187 -99
- tpu_inference/spec_decode/jax/eagle3.py +2 -1
- tpu_inference/tpu_info.py +4 -3
- tpu_inference/utils.py +5 -4
- tpu_inference/worker/tpu_worker.py +158 -22
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/METADATA +2 -2
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/RECORD +34 -39
- tpu_inference/mock/__init__.py +0 -0
- tpu_inference/mock/vllm_config_utils.py +0 -28
- tpu_inference/mock/vllm_envs.py +0 -1219
- tpu_inference/mock/vllm_logger.py +0 -212
- tpu_inference/mock/vllm_logging_utils.py +0 -15
- tpu_inference/models/jax/phi3.py +0 -376
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/top_level.txt +0 -0
tpu_inference/models/jax/llama_eagle3.py
@@ -194,13 +194,12 @@ class Eagle3LlamaModel(nnx.Module):
 
 def update_reshape_map_for_eagle3(vllm_config: VllmConfig,
                                   metadata_map: MetadataMap):
-    model_config = vllm_config.
+    model_config = vllm_config.speculative_config.draft_model_config
     hf_config = model_config.hf_config
 
     num_heads = hf_config.num_attention_heads
     num_kv_heads = hf_config.num_key_value_heads
-    hidden_size =
-
+    hidden_size = hf_config.hidden_size
     head_dim_original = model_config.get_head_size()
 
     metadata_map.reshape_map.update({
@@ -312,7 +311,11 @@ class EagleLlama3ForCausalLM(nnx.Module):
             r".*d2t.*",
         ]
 
-
+        # `embed_tokens` is shared between target and draft.
+        exclude_regex = [r".*embed_tokens.*"]
+        metadata_map = get_default_maps(
+            self.vllm_config.speculative_config.draft_model_config, self.mesh,
+            mappings)
 
         update_reshape_map_for_eagle3(self.vllm_config, metadata_map)
 
@@ -322,7 +325,8 @@ class EagleLlama3ForCausalLM(nnx.Module):
             metadata_map=metadata_map,
             mesh=self.mesh,
             is_draft_model=True,
-            keep_original_dtype_keys_regex=keep_original_dtype_keys_regex
+            keep_original_dtype_keys_regex=keep_original_dtype_keys_regex,
+            exclude_regex=exclude_regex if exclude_regex else None)
 
         # If the embedding is not initialized, initialize it with a dummpy array here to pass jit compilation. The real weights will be shared from the target model in eagle3 class.
         if isinstance(self.model.embed_tokens.embedding.value,
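The `exclude_regex` mechanism introduced above skips checkpoint keys by regular expression so that the Eagle3 draft model does not re-load the `embed_tokens` table it shares with the target. A minimal standalone sketch of that matching behaviour (the key names below are illustrative, not taken from a real checkpoint):

```python
import re

exclude_regex = [r".*embed_tokens.*"]

def should_skip(hf_key: str) -> bool:
    # A key is skipped if any exclusion pattern matches anywhere in it.
    return any(re.search(pattern, hf_key) for pattern in exclude_regex)

assert should_skip("model.embed_tokens.weight")
assert not should_skip("model.layers.0.self_attn.q_proj.weight")
```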
tpu_inference/models/jax/llama_guard_4.py (new file)
@@ -0,0 +1,361 @@
+import re
+from typing import Any, List, Optional, Tuple
+
+import jax
+import jax.numpy as jnp
+import torch
+from flax import nnx
+from flax.typing import PRNGKey
+from jax.sharding import Mesh
+from jax.sharding import PartitionSpec as P
+from vllm.config import VllmConfig
+
+from tpu_inference.layers.jax.attention.attention import AttentionMetadata
+from tpu_inference.layers.jax.attention.llama4_attention import Llama4Attention
+from tpu_inference.layers.jax.constants import KVCacheType
+from tpu_inference.layers.jax.layers import DenseFFW, Embedder, LMhead, RMSNorm
+from tpu_inference.layers.jax.misc import shard_put
+from tpu_inference.layers.jax.transformer_block import TransformerBlock
+from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.utils.weight_utils import (
+    get_param, model_weights_generator, print_param_info, reshape_params,
+    transpose_params)
+
+logger = init_logger(__name__)
+
+
+class LlamaGuard4ForCausalLM(nnx.Module):
+
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 rng: PRNGKey,
+                 mesh: Mesh,
+                 force_random_weights: bool = False):
+        logger.warning(
+            "🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨\n"
+            "Llama Guard 4 (JAX) is WIP: Only the text modality is currently implemented. "
+            "Multimodal inputs will fail.\n"
+            "🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨")
+        assert mesh is not None
+
+        self.vllm_config = vllm_config
+        self.vllm_config.model_config.dtype = torch.bfloat16
+        model_config = vllm_config.model_config
+        text_config = model_config.hf_config.text_config
+
+        self.mesh = mesh
+        self.is_verbose = getattr(self.vllm_config.additional_config,
+                                  "is_verbose", False)
+
+        self.use_qk_norm = getattr(text_config, "use_qk_norm", True)
+
+        vocab_size = model_config.get_vocab_size()
+        self.hidden_size = model_config.get_hidden_size()
+
+        self.dtype: jnp.dtype = jnp.bfloat16
+
+        self.num_layers: int = getattr(text_config, "num_layers", 48)
+        hidden_act: str = getattr(text_config, "hidden_act", "silu")
+
+        rms_norm_eps = getattr(text_config, "rms_norm_eps", 1e-5)
+        self.num_attention_heads = getattr(text_config, "num_attention_heads",
+                                           40)
+        self.num_key_value_heads = getattr(text_config, "num_key_value_heads",
+                                           8)
+        self.head_dim = getattr(text_config, "head_dim", 128)
+
+        intermediate_size = getattr(text_config, "intermediate_size", 8192)
+
+        self.rope_theta_text = getattr(text_config, "rope_theta", 500000.0)
+        self.rope_scaling = getattr(text_config, "rope_scaling")
+
+        self.rng = nnx.Rngs(rng)
+
+        self.embedder = Embedder(
+            vocab_size=vocab_size,
+            hidden_size=self.hidden_size,
+            dtype=self.dtype,
+            vd_sharding=(('data', 'model'), None),
+            rngs=self.rng,
+            random_init=force_random_weights,
+        )
+
+        self.layers = []
+
+        for i in range(self.num_layers):
+            use_attention_rope = True
+
+            custom_module = DenseFFW(dtype=self.dtype,
+                                     hidden_act=hidden_act,
+                                     hidden_size=self.hidden_size,
+                                     intermediate_size=intermediate_size,
+                                     random_init=force_random_weights,
+                                     rngs=self.rng,
+                                     df_sharding=P(None, 'model'),
+                                     fd_sharding=P('model', None),
+                                     activation_ffw_td=P('data', None))
+
+            attn = Llama4Attention(
+                hidden_size=self.hidden_size,
+                dtype=self.dtype,
+                num_attention_heads=self.num_attention_heads,
+                num_key_value_heads=self.num_key_value_heads,
+                head_dim=self.head_dim,
+                rope_theta=self.rope_theta_text,
+                rope_scaling={
+                    "scale_factor":
+                    self.rope_scaling["factor"],
+                    "low_freq_factor":
+                    self.rope_scaling["low_freq_factor"],
+                    "high_freq_factor":
+                    self.rope_scaling["high_freq_factor"],
+                    "original_max_position_embeddings":
+                    self.rope_scaling["original_max_position_embeddings"]
+                },
+                rngs=self.rng,
+                rope_input_ordering="interleaved",
+                # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
+                kv_cache_dtype=vllm_config.cache_config.cache_dtype,
+                temperature_tuning=True,
+                temperature_tuning_scale=0.1,
+                temperature_tuning_floor_scale=8192,
+                use_qk_norm=self.use_qk_norm,
+                attention_chunk_size=None if use_attention_rope else 8192,
+                mesh=self.mesh,
+                random_init=force_random_weights,
+                activation_attention_td=('data', 'model'),
+                activation_q_td=('data', 'model'),
+                query_tnh=P('data', 'model', None),
+                keyvalue_skh=P('data', 'model', None),
+                activation_attention_out_td=('data', 'model'),
+                attn_o_tnh=P('data', 'model', None),
+                dnh_sharding=(None, 'model', None),
+                dkh_sharding=(None, 'model', None),
+                nhd_sharding=('model', None, None),
+            )
+
+            pre_attention_norm = RMSNorm(
+                dims=self.hidden_size,
+                random_init=force_random_weights,
+                epsilon=rms_norm_eps,
+                rngs=self.rng,
+                activation_ffw_td=('data', None),
+                with_scale=True,
+                dtype=self.dtype,
+            )
+
+            pre_mlp_norm = RMSNorm(
+                dims=self.hidden_size,
+                activation_ffw_td=('data', None),
+                epsilon=rms_norm_eps,
+                rngs=self.rng,
+                with_scale=True,
+                dtype=self.dtype,
+                random_init=force_random_weights,
+            )
+
+            block = TransformerBlock(custom_module=custom_module,
+                                     attn=attn,
+                                     pre_attention_norm=pre_attention_norm,
+                                     pre_mlp_norm=pre_mlp_norm,
+                                     use_attention_rope=use_attention_rope)
+            self.layers.append(block)
+
+        self.final_norm = RMSNorm(
+            dims=self.hidden_size,
+            activation_ffw_td=P(),
+            epsilon=rms_norm_eps,
+            rngs=self.rng,
+            with_scale=True,
+            dtype=self.dtype,
+            random_init=force_random_weights,
+        )
+
+        self.lm_head = LMhead(vocab_size=vocab_size,
+                              hidden_size=self.hidden_size,
+                              dtype=self.dtype,
+                              rngs=self.rng,
+                              vd_sharding=(('data', 'model'), None),
+                              dv_sharding=(None, ('data', 'model')),
+                              random_init=force_random_weights)
+        if self.is_verbose:
+            self._print_model_architecture()
+
+    def _print_model_architecture(self):
+
+        logger.info("### Embedding ###")
+        nnx.display(self.embedder)
+
+        logger.info("\n### Layers ###")
+        for i, layer in enumerate(self.layers):
+            logger.info(f"\n--- Layer {i} ---")
+            nnx.display(layer)
+
+        logger.info("\n### LM Head ###")
+        nnx.display(self.lm_head)
+
+    def load_weights(self, rng: jax.Array, cache_dir: Optional[str] = None):
+        self.rng = nnx.Rngs(rng)
+
+        weight_loader = LlamaGuard4WeightLoader(
+            vllm_config=self.vllm_config,
+            hidden_size=self.hidden_size,
+            attn_heads=self.num_attention_heads,
+            num_key_value_heads=self.num_key_value_heads,
+            attn_head_dim=self.head_dim)
+        weight_loader.load_weights(self)
+
+    def __call__(
+        self,
+        kv_caches: List[jax.Array],
+        input_ids: jax.Array,
+        attention_metadata: AttentionMetadata,
+        inputs_embeds: Optional[jax.Array] = None,
+        layer_metadata_tuple: Optional[Tuple] = None,
+        lora_metadata: Optional[Any] = None,
+        *args,
+    ) -> Tuple[List[KVCacheType], jax.Array]:
+        is_prefill = False
+
+        if inputs_embeds is not None:
+            x_TD = inputs_embeds
+        elif input_ids is not None:
+            x_TD = self.embedder.encode(input_ids)
+        else:
+            raise ValueError(
+                "Cannot run forward pass: Both input_ids and inputs_embeds are None."
+            )
+
+        for (i, block) in enumerate(self.layers):
+            kv_cache = kv_caches[i]
+            new_kv_cache, x_TD = block(x_TD, is_prefill, kv_cache,
+                                       attention_metadata)
+            jax.block_until_ready(x_TD)
+            kv_caches[i] = new_kv_cache
+
+        final_activation_TD = self.final_norm(x_TD)
+
+        return kv_caches, final_activation_TD, []
+
+    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+        logits_TV = jnp.dot(hidden_states,
+                            self.lm_head.input_embedding_table_DV.value)
+        return logits_TV
+
+    def get_input_embeddings(
+        self,
+        input_ids: jax.Array,
+        multimodal_embeddings: Optional[List[jax.Array]] = None
+    ) -> jax.Array:
+        """
+        Computes the embeddings for text input (used for input to fusion).
+        """
+        return self.embedder.encode(input_ids)
+
+
+class LlamaGuard4WeightLoader:
+
+    def __init__(self, vllm_config: VllmConfig, hidden_size, attn_heads,
+                 num_key_value_heads, attn_head_dim):
+        self.names_and_weights_generator = model_weights_generator(
+            model_name_or_path=vllm_config.model_config.model,
+            framework="flax",
+            filter_regex="language_model",
+            download_dir=vllm_config.load_config.download_dir)
+        self.is_verbose = getattr(vllm_config.additional_config, "is_verbose",
+                                  False)
+        self._transpose_map = {
+            "q_proj": (2, 0, 1),
+            "k_proj": (2, 0, 1),
+            "v_proj": (2, 0, 1),
+            "o_proj": (1, 2, 0),
+            "lm_head": (1, 0),
+            "feed_forward.down_proj": (1, 0),
+            "feed_forward.gate_proj": (1, 0),
+            "feed_forward.up_proj": (1, 0),
+            "mlp.down_proj": (1, 0),
+            "mlp.gate_proj": (1, 0),
+            "mlp.up_proj": (1, 0),
+        }
+        self._weight_shape_map = {
+            "q_proj": (attn_heads, attn_head_dim, hidden_size),
+            "k_proj": (num_key_value_heads, attn_head_dim, hidden_size),
+            "v_proj": (num_key_value_heads, attn_head_dim, hidden_size),
+            "o_proj": (hidden_size, attn_heads, attn_head_dim),
+        }
+
+        self._loaded_to_standardized_keys = {
+            "language_model.model.embed_tokens.weight":
+            "embedder.input_embedding_table_VD",
+            "language_model.lm_head.weight":
+            "lm_head.input_embedding_table_DV",
+            "language_model.model.norm.weight":
+            "final_norm.scale",
+            "language_model.model.layers.*.input_layernorm.weight":
+            "layers.*.pre_attention_norm.scale",
+            "language_model.model.layers.*.post_attention_layernorm.weight":
+            "layers.*.pre_mlp_norm.scale",
+            "language_model.model.layers.*.self_attn.q_proj.weight":
+            "layers.*.attn.kernel_q_proj_DNH",
+            "language_model.model.layers.*.self_attn.k_proj.weight":
+            "layers.*.attn.kernel_k_proj_DKH",
+            "language_model.model.layers.*.self_attn.v_proj.weight":
+            "layers.*.attn.kernel_v_proj_DKH",
+            "language_model.model.layers.*.self_attn.o_proj.weight":
+            "layers.*.attn.kernel_o_proj_NHD",
+            "language_model.model.layers.*.feed_forward.gate_proj.weight":
+            "layers.*.custom_module.kernel_gating_DF",
+            "language_model.model.layers.*.feed_forward.up_proj.weight":
+            "layers.*.custom_module.kernel_up_proj_DF",
+            "language_model.model.layers.*.feed_forward.down_proj.weight":
+            "layers.*.custom_module.kernel_down_proj_FD",
+        }
+
+    def map_loaded_to_standardized_name(self, loaded_key: str) -> str:
+        if "layer" in loaded_key:
+            layer_num = re.search(r"layers\.(\d+)", loaded_key).group(1)
+            layer_key = re.sub(r"layers\.\d+", "layers.*", loaded_key)
+            mapped_key = self._loaded_to_standardized_keys.get(
+                layer_key, loaded_key)
+            mapped_key = re.sub(r"layers\.\*", f"layers.{layer_num}",
+                                mapped_key)
+        else:
+            mapped_key = self._loaded_to_standardized_keys.get(
+                loaded_key, loaded_key)
+        return mapped_key
+
+    def load_weights(self, model_for_loading: nnx.Module):
+        model_params = nnx.state(model_for_loading)
+        with jax.default_device(jax.devices("cpu")[0]):
+            for loaded_name, loaded_weight in self.names_and_weights_generator:
+                if loaded_name.endswith(".bias"):
+                    continue
+                if "vision_model" in loaded_name or "multi_modal_projector" in loaded_name:
+                    continue
+
+                mapped_name = self.map_loaded_to_standardized_name(loaded_name)
+                model_weight = get_param(model_params, mapped_name)
+
+                if not loaded_name.endswith(".bias"):
+                    # For other layers, continue to use the transpose_params helper.
+                    loaded_weight = reshape_params(loaded_name, loaded_weight,
+                                                   self._weight_shape_map)
+                    loaded_weight = transpose_params(loaded_name,
+                                                     loaded_weight,
+                                                     self._transpose_map)
+                if model_weight.value.shape != loaded_weight.shape:
+                    raise ValueError(
+                        f"Loaded shape for {loaded_name}: {loaded_weight.shape} "
+                        f"does not match model shape for {mapped_name}: {model_weight.value.shape}!"
+                    )
+                logger.debug(
+                    f"Transformed parameter {loaded_name} to {mapped_name}: {loaded_weight.shape} --> {model_weight.value.shape}"
+                )
+
+                model_weight.value = shard_put(loaded_weight,
+                                               model_weight.sharding,
+                                               mesh=model_for_loading.mesh)
+                if self.is_verbose:
+                    print_param_info(model_weight, loaded_name)
+
+        nnx.update(model_for_loading, model_params)
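The new `LlamaGuard4WeightLoader.map_loaded_to_standardized_name` normalises per-layer checkpoint keys by swapping the layer index for a `*` wildcard, looking up the template, and substituting the index back. A small standalone sketch of that mapping (only one template entry is shown; the full table is in the file above):

```python
import re

LOADED_TO_STANDARDIZED = {
    "language_model.model.layers.*.self_attn.q_proj.weight":
    "layers.*.attn.kernel_q_proj_DNH",
}

def map_name(loaded_key: str) -> str:
    match = re.search(r"layers\.(\d+)", loaded_key)
    if match is None:
        # Non-layer keys are looked up directly (or passed through unchanged).
        return LOADED_TO_STANDARDIZED.get(loaded_key, loaded_key)
    # Replace the concrete index with "*", look up the template, then restore it.
    template = re.sub(r"layers\.\d+", "layers.*", loaded_key)
    mapped = LOADED_TO_STANDARDIZED.get(template, loaded_key)
    return re.sub(r"layers\.\*", f"layers.{match.group(1)}", mapped)

print(map_name("language_model.model.layers.7.self_attn.q_proj.weight"))
# -> layers.7.attn.kernel_q_proj_DNH
```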
tpu_inference/models/jax/qwen2.py
@@ -368,7 +368,8 @@ class Qwen2ForCausalLM(nnx.Module):
             "lm_head": "model.lm_head",
         })
 
-        metadata_map = get_default_maps(self.vllm_config
+        metadata_map = get_default_maps(self.vllm_config.model_config,
+                                        self.mesh, mappings)
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
tpu_inference/models/jax/qwen2_5_vl.py
@@ -1061,7 +1061,8 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
             "lm_head": "language_model.model.lm_head",
         })
 
-        metadata_map = get_default_maps(self.vllm_config
+        metadata_map = get_default_maps(self.vllm_config.model_config,
+                                        self.mesh, mappings)
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
tpu_inference/models/jax/qwen3.py
@@ -295,7 +295,8 @@ class Qwen3ForCausalLM(nnx.Module):
             "lm_head": "model.lm_head",
         })
 
-        metadata_map = get_default_maps(self.vllm_config
+        metadata_map = get_default_maps(self.vllm_config.model_config,
+                                        self.mesh, mappings)
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
tpu_inference/models/jax/utils/weight_utils.py
@@ -18,7 +18,7 @@ from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from safetensors import safe_open
 
-from tpu_inference import utils
+from tpu_inference import envs, utils
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils import file_utils
 
@@ -197,12 +197,11 @@ def shard_put(x: jax.Array, shardings, mesh: jax.sharding.Mesh) -> jax.Array:
     return jax.device_put(x, shardings)
 
 
-def get_default_maps(
+def get_default_maps(model_config, mesh: Mesh,
                      name_map: dict[str, str]) -> MetadataMap:
     """Load weights from one model weights file to the model, run on single thread."""
     sharding_size = mesh.shape["model"]
 
-    model_config = vllm_config.model_config
     hf_config = model_config.hf_config
 
     num_heads = hf_config.num_attention_heads
@@ -273,7 +272,8 @@ def _load_hf_weights_on_thread(vllm_config,
                                weights_file: str,
                                filter_regex: str | None = None,
                                keep_original_dtype_keys_regex: list[str]
-                               | None = None
+                               | None = None,
+                               exclude_regex: list[str] | None = None):
     name_map = metadata_map.name_map
     reshape_keys = metadata_map.reshape_map
     bias_reshape_keys = metadata_map.bias_reshape_map
@@ -298,6 +298,18 @@ def _load_hf_weights_on_thread(vllm_config,
     for hf_key, hf_weight in model_weights_single_file_generator(
             weights_file, framework="flax", filter_regex=filter_regex):
 
+        # Check if the key should be excluded
+        if exclude_regex:
+            should_exclude = False
+            for pattern in exclude_regex:
+                if re.search(pattern, hf_key):
+                    logger.info(
+                        f"Excluding {hf_key} based on pattern {pattern}")
+                    should_exclude = True
+                    break
+            if should_exclude:
+                continue
+
         # Check if the key should retain its original dtype
         keep_original_dtype = False
         if keep_original_dtype_keys_regex:
@@ -408,7 +420,8 @@ def load_hf_weights(vllm_config,
                     mesh: Mesh,
                     filter_regex: str | None = None,
                     is_draft_model: bool = False,
-                    keep_original_dtype_keys_regex: list[str] | None = None
+                    keep_original_dtype_keys_regex: list[str] | None = None,
+                    exclude_regex: list[str] | None = None):
     """Load weights from all model weights files to the model, run in multi threads."""
     if is_draft_model:
         model_path = vllm_config.speculative_config.draft_model_config.model
@@ -421,7 +434,7 @@ def load_hf_weights(vllm_config,
     # NOTE(xiang): Disable multi-threading mode if running on multi-host.
     # Because multi-threading would cause different JAX processes to load
    # different weights at the same time.
-    if
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         max_workers = 1
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
         futures = [
@@ -433,8 +446,8 @@ def load_hf_weights(vllm_config,
                 mesh,
                 weights_file,
                 filter_regex=filter_regex,
-                keep_original_dtype_keys_regex=keep_original_dtype_keys_regex
-
+                keep_original_dtype_keys_regex=keep_original_dtype_keys_regex,
+                exclude_regex=exclude_regex) for weights_file in weights_files
         ]
         for future in futures:
             future.result()
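`load_hf_weights` continues to fan out one loading task per weights file via a `ThreadPoolExecutor`, now falling back to a single worker when `envs.TPU_MULTIHOST_BACKEND` is `"ray"`. A toy sketch of that scheduling pattern (the file names, default worker count, and stand-in loader below are invented for the demo):

```python
from concurrent.futures import ThreadPoolExecutor

weights_files = ["model-00001.safetensors", "model-00002.safetensors"]
multihost_backend = ""  # stand-in for envs.TPU_MULTIHOST_BACKEND

max_workers = 4
if multihost_backend == "ray":
    # Multi-host: serialize loading so all JAX processes see the same order.
    max_workers = 1

def load_one(path: str) -> str:
    # Stands in for _load_hf_weights_on_thread(...)
    return f"loaded {path}"

with ThreadPoolExecutor(max_workers=max_workers) as executor:
    futures = [executor.submit(load_one, f) for f in weights_files]
    for future in futures:
        print(future.result())
```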
tpu_inference/models/vllm/vllm_model_wrapper.py
@@ -120,8 +120,7 @@ class VllmModelWrapper:
 
         # Load the vLLM model and wrap it into a new model whose forward
         # function can calculate the hidden_state and logits.
-
-        with load_context, jax.default_device(available_devices[0]):
+        with load_context, jax.default_device(jax.devices("cpu")[0]):
             vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
         lora_manager = None
         if vllm_config_for_load.lora_config is not None:
@@ -162,6 +161,7 @@ class VllmModelWrapper:
         input_ids: jax.Array,
         attn_metadata: AttentionMetadata,
         input_embeds: jax.Array,
+        input_positions: jax.Array,
         layer_name_to_kvcache_index: Sequence[Tuple[str, int]],
         lora_metadata,
         intermediate_tensors: JaxIntermediateTensors = None,
@@ -188,8 +188,8 @@ class VllmModelWrapper:
             torch_view(params_and_buffers),
             kwargs={
                 "input_ids": torch_view(input_ids),
-                "positions": torch_view(
-                "intermediate_tensors":
+                "positions": torch_view(input_positions),
+                "intermediate_tensors": None,
                 "inputs_embeds": None,
             },
             tie_weights=False,
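`VllmModelWrapper` drives the wrapped PyTorch model through a functional call with an explicit `positions` keyword, which is what the new `input_positions` argument feeds. A toy sketch of that call shape using `torch.func.functional_call` on a stand-in module (the module, tensors, and shapes below are invented; only the pattern mirrors the wrapper):

```python
import torch
from torch.func import functional_call

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = torch.nn.Embedding(16, 8)

    def forward(self, input_ids, positions):
        # Consume positions explicitly, mirroring the wrapper's new
        # `input_positions` argument.
        return self.emb(input_ids) + positions.unsqueeze(-1).float()

model = ToyModel()
params_and_buffers = {**dict(model.named_parameters()),
                      **dict(model.named_buffers())}
out = functional_call(
    model,
    params_and_buffers,
    args=(),
    kwargs={
        "input_ids": torch.tensor([1, 2, 3]),
        "positions": torch.tensor([0, 1, 2]),
    },
    tie_weights=False,
)
print(out.shape)  # torch.Size([3, 8])
```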
tpu_inference/platforms/tpu_platform.py
@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import os
 from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
 
 import jax.numpy as jnp
@@ -183,7 +182,7 @@ class TpuPlatform(Platform):
             parallel_config.worker_cls = \
                 "tpu_inference.worker.tpu_worker.TPUWorker"
 
-        multihost_backend =
+        multihost_backend = envs.TPU_MULTIHOST_BACKEND
         if not multihost_backend:  # Single host
             if parallel_config.pipeline_parallel_size == 1:
                 logger.info("Force using UniProcExecutor for JAX on \
@@ -267,3 +266,7 @@ class TpuPlatform(Platform):
         Returns if the current platform needs to sync weight loader.
         """
         return True
+
+    @classmethod
+    def support_hybrid_kv_cache(cls) -> bool:
+        return True
tpu_inference/runner/compilation_manager.py
@@ -1,6 +1,6 @@
 import os
 import time
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
 
 import jax
 import jax.numpy as jnp
@@ -135,12 +135,6 @@ class CompilationManager:
             ShardingAxisName.ATTN_DATA, )) if dp_size > 1 else None
 
         # Keep existing pattern for complex array operations
-        block_tables = self.runner.block_table_cpu[:self.runner.max_num_reqs]
-        block_tables = block_tables.reshape(-1)
-        block_tables = device_array(self.runner.mesh,
-                                    block_tables,
-                                    sharding=dp_sharding)
-
         seq_lens = self._create_dummy_tensor((self.runner.max_num_reqs, ),
                                              jnp.int32, dp_sharding)
         query_start_loc = self._create_dummy_tensor(
@@ -152,26 +146,45 @@ class CompilationManager:
             request_distribution,
             sharding=dp_sharding)
 
-
-
-
-
-
-
-
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.runner.kv_cache_config.kv_cache_groups):
+            block_tables = self.runner.block_tables_cpu[
+                kv_cache_gid][:self.runner.max_num_reqs]
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(self.runner.mesh,
+                                        block_tables,
+                                        sharding=dp_sharding)
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution,
+            )
+            if not self.runner.use_hybrid_kvcache:
+                # all layers share the same attention metadata
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid
 
         def model_fn_wrapper(
             state,
             kv_caches,
             input_ids,
             attention_metadata,
+            positions,
             inputs_embeds,
             layer_name_to_kvcache_index,
             lora_metadata,
         ):
             kv_caches, hidden_states, _ = self.runner.model_fn(
                 state, kv_caches, input_ids, attention_metadata, inputs_embeds,
-                layer_name_to_kvcache_index, lora_metadata)
+                positions, layer_name_to_kvcache_index, lora_metadata)
             self.runner.kv_caches = kv_caches
             return hidden_states
 
@@ -179,6 +192,10 @@ class CompilationManager:
                 self.runner.lora_config, np.array([num_tokens],
                                                   dtype=np.int32)):
             lora_metadata = self.runner.lora_utils.extract_lora_metadata()
+            if self.runner.use_hybrid_kvcache:
+                attention_metadata = attention_metadata_per_layer
+            else:
+                attention_metadata = uniform_attention_metadata
             self._run_compilation(
                 name,
                 model_fn_wrapper,
@@ -186,6 +203,7 @@ class CompilationManager:
                 self.runner.kv_caches,
                 input_ids,
                 attention_metadata,
+                positions,
                 inputs_embeds,
                 tuple(self.runner.layer_name_to_kvcache_index.items()),
                 lora_metadata,
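With hybrid KV caches, every layer in a KV-cache group now shares one `AttentionMetadata` object during precompilation, while non-hybrid setups keep a single uniform object. A toy sketch of that fan-out with made-up group and layer names:

```python
from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class DummyMetadata:
    group_id: int

@dataclass
class DummyGroup:
    layer_names: List[str]

kv_cache_groups = [
    DummyGroup(layer_names=["layers.0.attn", "layers.2.attn"]),
    DummyGroup(layer_names=["layers.1.attn", "layers.3.attn"]),
]
use_hybrid_kvcache = True

per_layer: Dict[str, DummyMetadata] = {}
uniform: Optional[DummyMetadata] = None
for gid, group in enumerate(kv_cache_groups):
    metadata = DummyMetadata(group_id=gid)
    if not use_hybrid_kvcache:
        uniform = metadata  # every layer sees the same metadata
    else:
        for name in group.layer_names:
            per_layer[name] = metadata  # layers in a group share one object

print(per_layer["layers.2.attn"].group_id)  # -> 0
```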