tpu-inference: tpu_inference-0.11.1.dev202511180814-py3-none-any.whl → tpu_inference-0.11.1.dev202511220812-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.
Files changed (40):
  1. tests/lora/test_layers.py +0 -6
  2. tests/lora/utils.py +0 -8
  3. tpu_inference/__init__.py +22 -3
  4. tpu_inference/core/disagg_utils.py +6 -8
  5. tpu_inference/distributed/tpu_connector.py +2 -3
  6. tpu_inference/distributed/utils.py +3 -2
  7. tpu_inference/envs.py +1 -1
  8. tpu_inference/executors/ray_distributed_executor.py +4 -1
  9. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
  10. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +77 -54
  11. tpu_inference/layers/vllm/sharding.py +2 -2
  12. tpu_inference/lora/torch_punica_tpu.py +1 -2
  13. tpu_inference/models/common/model_loader.py +9 -9
  14. tpu_inference/models/jax/llama3.py +2 -1
  15. tpu_inference/models/jax/llama_eagle3.py +9 -5
  16. tpu_inference/models/jax/llama_guard_4.py +361 -0
  17. tpu_inference/models/jax/qwen2.py +2 -1
  18. tpu_inference/models/jax/qwen2_5_vl.py +2 -1
  19. tpu_inference/models/jax/qwen3.py +2 -1
  20. tpu_inference/models/jax/utils/weight_utils.py +21 -8
  21. tpu_inference/models/vllm/vllm_model_wrapper.py +4 -4
  22. tpu_inference/platforms/tpu_platform.py +5 -2
  23. tpu_inference/runner/compilation_manager.py +33 -15
  24. tpu_inference/runner/kv_cache_manager.py +8 -2
  25. tpu_inference/runner/tpu_runner.py +187 -99
  26. tpu_inference/spec_decode/jax/eagle3.py +2 -1
  27. tpu_inference/tpu_info.py +4 -3
  28. tpu_inference/utils.py +5 -4
  29. tpu_inference/worker/tpu_worker.py +158 -22
  30. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/METADATA +2 -2
  31. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/RECORD +34 -39
  32. tpu_inference/mock/__init__.py +0 -0
  33. tpu_inference/mock/vllm_config_utils.py +0 -28
  34. tpu_inference/mock/vllm_envs.py +0 -1219
  35. tpu_inference/mock/vllm_logger.py +0 -212
  36. tpu_inference/mock/vllm_logging_utils.py +0 -15
  37. tpu_inference/models/jax/phi3.py +0 -376
  38. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/WHEEL +0 -0
  39. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/licenses/LICENSE +0 -0
  40. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/top_level.txt +0 -0
tpu_inference/models/jax/llama_eagle3.py

@@ -194,13 +194,12 @@ class Eagle3LlamaModel(nnx.Module):
 
 def update_reshape_map_for_eagle3(vllm_config: VllmConfig,
                                   metadata_map: MetadataMap):
-    model_config = vllm_config.model_config
+    model_config = vllm_config.speculative_config.draft_model_config
     hf_config = model_config.hf_config
 
     num_heads = hf_config.num_attention_heads
     num_kv_heads = hf_config.num_key_value_heads
-    hidden_size = model_config.get_hidden_size()
-
+    hidden_size = hf_config.hidden_size
     head_dim_original = model_config.get_head_size()
 
     metadata_map.reshape_map.update({
@@ -312,7 +311,11 @@ class EagleLlama3ForCausalLM(nnx.Module):
             r".*d2t.*",
         ]
 
-        metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
+        # `embed_tokens` is shared between target and draft.
+        exclude_regex = [r".*embed_tokens.*"]
+        metadata_map = get_default_maps(
+            self.vllm_config.speculative_config.draft_model_config, self.mesh,
+            mappings)
 
         update_reshape_map_for_eagle3(self.vllm_config, metadata_map)
 
@@ -322,7 +325,8 @@ class EagleLlama3ForCausalLM(nnx.Module):
             metadata_map=metadata_map,
             mesh=self.mesh,
             is_draft_model=True,
-            keep_original_dtype_keys_regex=keep_original_dtype_keys_regex)
+            keep_original_dtype_keys_regex=keep_original_dtype_keys_regex,
+            exclude_regex=exclude_regex if exclude_regex else None)
 
         # If the embedding is not initialized, initialize it with a dummy array here to pass jit compilation. The real weights will be shared from the target model in the eagle3 class.
         if isinstance(self.model.embed_tokens.embedding.value,
tpu_inference/models/jax/llama_guard_4.py (new file)

@@ -0,0 +1,361 @@
+import re
+from typing import Any, List, Optional, Tuple
+
+import jax
+import jax.numpy as jnp
+import torch
+from flax import nnx
+from flax.typing import PRNGKey
+from jax.sharding import Mesh
+from jax.sharding import PartitionSpec as P
+from vllm.config import VllmConfig
+
+from tpu_inference.layers.jax.attention.attention import AttentionMetadata
+from tpu_inference.layers.jax.attention.llama4_attention import Llama4Attention
+from tpu_inference.layers.jax.constants import KVCacheType
+from tpu_inference.layers.jax.layers import DenseFFW, Embedder, LMhead, RMSNorm
+from tpu_inference.layers.jax.misc import shard_put
+from tpu_inference.layers.jax.transformer_block import TransformerBlock
+from tpu_inference.logger import init_logger
+from tpu_inference.models.jax.utils.weight_utils import (
+    get_param, model_weights_generator, print_param_info, reshape_params,
+    transpose_params)
+
+logger = init_logger(__name__)
+
+
+class LlamaGuard4ForCausalLM(nnx.Module):
+
+    def __init__(self,
+                 vllm_config: VllmConfig,
+                 rng: PRNGKey,
+                 mesh: Mesh,
+                 force_random_weights: bool = False):
+        logger.warning(
+            "🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨\n"
+            "Llama Guard 4 (JAX) is WIP: Only the text modality is currently implemented. "
+            "Multimodal inputs will fail.\n"
+            "🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨")
+        assert mesh is not None
+
+        self.vllm_config = vllm_config
+        self.vllm_config.model_config.dtype = torch.bfloat16
+        model_config = vllm_config.model_config
+        text_config = model_config.hf_config.text_config
+
+        self.mesh = mesh
+        self.is_verbose = getattr(self.vllm_config.additional_config,
+                                  "is_verbose", False)
+
+        self.use_qk_norm = getattr(text_config, "use_qk_norm", True)
+
+        vocab_size = model_config.get_vocab_size()
+        self.hidden_size = model_config.get_hidden_size()
+
+        self.dtype: jnp.dtype = jnp.bfloat16
+
+        self.num_layers: int = getattr(text_config, "num_layers", 48)
+        hidden_act: str = getattr(text_config, "hidden_act", "silu")
+
+        rms_norm_eps = getattr(text_config, "rms_norm_eps", 1e-5)
+        self.num_attention_heads = getattr(text_config, "num_attention_heads",
+                                           40)
+        self.num_key_value_heads = getattr(text_config, "num_key_value_heads",
+                                           8)
+        self.head_dim = getattr(text_config, "head_dim", 128)
+
+        intermediate_size = getattr(text_config, "intermediate_size", 8192)
+
+        self.rope_theta_text = getattr(text_config, "rope_theta", 500000.0)
+        self.rope_scaling = getattr(text_config, "rope_scaling")
+
+        self.rng = nnx.Rngs(rng)
+
+        self.embedder = Embedder(
+            vocab_size=vocab_size,
+            hidden_size=self.hidden_size,
+            dtype=self.dtype,
+            vd_sharding=(('data', 'model'), None),
+            rngs=self.rng,
+            random_init=force_random_weights,
+        )
+
+        self.layers = []
+
+        for i in range(self.num_layers):
+            use_attention_rope = True
+
+            custom_module = DenseFFW(dtype=self.dtype,
+                                     hidden_act=hidden_act,
+                                     hidden_size=self.hidden_size,
+                                     intermediate_size=intermediate_size,
+                                     random_init=force_random_weights,
+                                     rngs=self.rng,
+                                     df_sharding=P(None, 'model'),
+                                     fd_sharding=P('model', None),
+                                     activation_ffw_td=P('data', None))
+
+            attn = Llama4Attention(
+                hidden_size=self.hidden_size,
+                dtype=self.dtype,
+                num_attention_heads=self.num_attention_heads,
+                num_key_value_heads=self.num_key_value_heads,
+                head_dim=self.head_dim,
+                rope_theta=self.rope_theta_text,
+                rope_scaling={
+                    "scale_factor":
+                    self.rope_scaling["factor"],
+                    "low_freq_factor":
+                    self.rope_scaling["low_freq_factor"],
+                    "high_freq_factor":
+                    self.rope_scaling["high_freq_factor"],
+                    "original_max_position_embeddings":
+                    self.rope_scaling["original_max_position_embeddings"]
+                },
+                rngs=self.rng,
+                rope_input_ordering="interleaved",
+                # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
+                kv_cache_dtype=vllm_config.cache_config.cache_dtype,
+                temperature_tuning=True,
+                temperature_tuning_scale=0.1,
+                temperature_tuning_floor_scale=8192,
+                use_qk_norm=self.use_qk_norm,
+                attention_chunk_size=None if use_attention_rope else 8192,
+                mesh=self.mesh,
+                random_init=force_random_weights,
+                activation_attention_td=('data', 'model'),
+                activation_q_td=('data', 'model'),
+                query_tnh=P('data', 'model', None),
+                keyvalue_skh=P('data', 'model', None),
+                activation_attention_out_td=('data', 'model'),
+                attn_o_tnh=P('data', 'model', None),
+                dnh_sharding=(None, 'model', None),
+                dkh_sharding=(None, 'model', None),
+                nhd_sharding=('model', None, None),
+            )
+
+            pre_attention_norm = RMSNorm(
+                dims=self.hidden_size,
+                random_init=force_random_weights,
+                epsilon=rms_norm_eps,
+                rngs=self.rng,
+                activation_ffw_td=('data', None),
+                with_scale=True,
+                dtype=self.dtype,
+            )
+
+            pre_mlp_norm = RMSNorm(
+                dims=self.hidden_size,
+                activation_ffw_td=('data', None),
+                epsilon=rms_norm_eps,
+                rngs=self.rng,
+                with_scale=True,
+                dtype=self.dtype,
+                random_init=force_random_weights,
+            )
+
+            block = TransformerBlock(custom_module=custom_module,
+                                     attn=attn,
+                                     pre_attention_norm=pre_attention_norm,
+                                     pre_mlp_norm=pre_mlp_norm,
+                                     use_attention_rope=use_attention_rope)
+            self.layers.append(block)
+
+        self.final_norm = RMSNorm(
+            dims=self.hidden_size,
+            activation_ffw_td=P(),
+            epsilon=rms_norm_eps,
+            rngs=self.rng,
+            with_scale=True,
+            dtype=self.dtype,
+            random_init=force_random_weights,
+        )
+
+        self.lm_head = LMhead(vocab_size=vocab_size,
+                              hidden_size=self.hidden_size,
+                              dtype=self.dtype,
+                              rngs=self.rng,
+                              vd_sharding=(('data', 'model'), None),
+                              dv_sharding=(None, ('data', 'model')),
+                              random_init=force_random_weights)
+        if self.is_verbose:
+            self._print_model_architecture()
+
+    def _print_model_architecture(self):
+
+        logger.info("### Embedding ###")
+        nnx.display(self.embedder)
+
+        logger.info("\n### Layers ###")
+        for i, layer in enumerate(self.layers):
+            logger.info(f"\n--- Layer {i} ---")
+            nnx.display(layer)
+
+        logger.info("\n### LM Head ###")
+        nnx.display(self.lm_head)
+
+    def load_weights(self, rng: jax.Array, cache_dir: Optional[str] = None):
+        self.rng = nnx.Rngs(rng)
+
+        weight_loader = LlamaGuard4WeightLoader(
+            vllm_config=self.vllm_config,
+            hidden_size=self.hidden_size,
+            attn_heads=self.num_attention_heads,
+            num_key_value_heads=self.num_key_value_heads,
+            attn_head_dim=self.head_dim)
+        weight_loader.load_weights(self)
+
+    def __call__(
+        self,
+        kv_caches: List[jax.Array],
+        input_ids: jax.Array,
+        attention_metadata: AttentionMetadata,
+        inputs_embeds: Optional[jax.Array] = None,
+        layer_metadata_tuple: Optional[Tuple] = None,
+        lora_metadata: Optional[Any] = None,
+        *args,
+    ) -> Tuple[List[KVCacheType], jax.Array]:
+        is_prefill = False
+
+        if inputs_embeds is not None:
+            x_TD = inputs_embeds
+        elif input_ids is not None:
+            x_TD = self.embedder.encode(input_ids)
+        else:
+            raise ValueError(
+                "Cannot run forward pass: Both input_ids and inputs_embeds are None."
+            )
+
+        for (i, block) in enumerate(self.layers):
+            kv_cache = kv_caches[i]
+            new_kv_cache, x_TD = block(x_TD, is_prefill, kv_cache,
+                                       attention_metadata)
+            jax.block_until_ready(x_TD)
+            kv_caches[i] = new_kv_cache
+
+        final_activation_TD = self.final_norm(x_TD)
+
+        return kv_caches, final_activation_TD, []
+
+    def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+        logits_TV = jnp.dot(hidden_states,
+                            self.lm_head.input_embedding_table_DV.value)
+        return logits_TV
+
+    def get_input_embeddings(
+        self,
+        input_ids: jax.Array,
+        multimodal_embeddings: Optional[List[jax.Array]] = None
+    ) -> jax.Array:
+        """
+        Computes the embeddings for text input (used for input to fusion).
+        """
+        return self.embedder.encode(input_ids)
+
+
+class LlamaGuard4WeightLoader:
+
+    def __init__(self, vllm_config: VllmConfig, hidden_size, attn_heads,
+                 num_key_value_heads, attn_head_dim):
+        self.names_and_weights_generator = model_weights_generator(
+            model_name_or_path=vllm_config.model_config.model,
+            framework="flax",
+            filter_regex="language_model",
+            download_dir=vllm_config.load_config.download_dir)
+        self.is_verbose = getattr(vllm_config.additional_config, "is_verbose",
+                                  False)
+        self._transpose_map = {
+            "q_proj": (2, 0, 1),
+            "k_proj": (2, 0, 1),
+            "v_proj": (2, 0, 1),
+            "o_proj": (1, 2, 0),
+            "lm_head": (1, 0),
+            "feed_forward.down_proj": (1, 0),
+            "feed_forward.gate_proj": (1, 0),
+            "feed_forward.up_proj": (1, 0),
+            "mlp.down_proj": (1, 0),
+            "mlp.gate_proj": (1, 0),
+            "mlp.up_proj": (1, 0),
+        }
+        self._weight_shape_map = {
+            "q_proj": (attn_heads, attn_head_dim, hidden_size),
+            "k_proj": (num_key_value_heads, attn_head_dim, hidden_size),
+            "v_proj": (num_key_value_heads, attn_head_dim, hidden_size),
+            "o_proj": (hidden_size, attn_heads, attn_head_dim),
+        }
+
+        self._loaded_to_standardized_keys = {
+            "language_model.model.embed_tokens.weight":
+            "embedder.input_embedding_table_VD",
+            "language_model.lm_head.weight":
+            "lm_head.input_embedding_table_DV",
+            "language_model.model.norm.weight":
+            "final_norm.scale",
+            "language_model.model.layers.*.input_layernorm.weight":
+            "layers.*.pre_attention_norm.scale",
+            "language_model.model.layers.*.post_attention_layernorm.weight":
+            "layers.*.pre_mlp_norm.scale",
+            "language_model.model.layers.*.self_attn.q_proj.weight":
+            "layers.*.attn.kernel_q_proj_DNH",
+            "language_model.model.layers.*.self_attn.k_proj.weight":
+            "layers.*.attn.kernel_k_proj_DKH",
+            "language_model.model.layers.*.self_attn.v_proj.weight":
+            "layers.*.attn.kernel_v_proj_DKH",
+            "language_model.model.layers.*.self_attn.o_proj.weight":
+            "layers.*.attn.kernel_o_proj_NHD",
+            "language_model.model.layers.*.feed_forward.gate_proj.weight":
+            "layers.*.custom_module.kernel_gating_DF",
+            "language_model.model.layers.*.feed_forward.up_proj.weight":
+            "layers.*.custom_module.kernel_up_proj_DF",
+            "language_model.model.layers.*.feed_forward.down_proj.weight":
+            "layers.*.custom_module.kernel_down_proj_FD",
+        }
+
+    def map_loaded_to_standardized_name(self, loaded_key: str) -> str:
+        if "layer" in loaded_key:
+            layer_num = re.search(r"layers\.(\d+)", loaded_key).group(1)
+            layer_key = re.sub(r"layers\.\d+", "layers.*", loaded_key)
+            mapped_key = self._loaded_to_standardized_keys.get(
+                layer_key, loaded_key)
+            mapped_key = re.sub(r"layers\.\*", f"layers.{layer_num}",
+                                mapped_key)
+        else:
+            mapped_key = self._loaded_to_standardized_keys.get(
+                loaded_key, loaded_key)
+        return mapped_key
+
+    def load_weights(self, model_for_loading: nnx.Module):
+        model_params = nnx.state(model_for_loading)
+        with jax.default_device(jax.devices("cpu")[0]):
+            for loaded_name, loaded_weight in self.names_and_weights_generator:
+                if loaded_name.endswith(".bias"):
+                    continue
+                if "vision_model" in loaded_name or "multi_modal_projector" in loaded_name:
+                    continue
+
+                mapped_name = self.map_loaded_to_standardized_name(loaded_name)
+                model_weight = get_param(model_params, mapped_name)
+
+                if not loaded_name.endswith(".bias"):
+                    # For other layers, continue to use the transpose_params helper.
+                    loaded_weight = reshape_params(loaded_name, loaded_weight,
+                                                   self._weight_shape_map)
+                    loaded_weight = transpose_params(loaded_name,
+                                                     loaded_weight,
+                                                     self._transpose_map)
+                if model_weight.value.shape != loaded_weight.shape:
+                    raise ValueError(
+                        f"Loaded shape for {loaded_name}: {loaded_weight.shape} "
+                        f"does not match model shape for {mapped_name}: {model_weight.value.shape}!"
+                    )
+                logger.debug(
+                    f"Transformed parameter {loaded_name} to {mapped_name}: {loaded_weight.shape} --> {model_weight.value.shape}"
+                )
+
+                model_weight.value = shard_put(loaded_weight,
+                                               model_weight.sharding,
+                                               mesh=model_for_loading.mesh)
+                if self.is_verbose:
+                    print_param_info(model_weight, loaded_name)
+
+        nnx.update(model_for_loading, model_params)
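The loader's key mapping relies on a layers.* wildcard that is rewritten per layer index. A minimal standalone sketch of that lookup, with the mapping table abbreviated to a single entry taken from the diff:

    import re

    _KEYS = {
        "language_model.model.layers.*.self_attn.q_proj.weight":
        "layers.*.attn.kernel_q_proj_DNH",
    }

    def map_name(loaded_key: str) -> str:
        if "layer" in loaded_key:
            # Replace the concrete index with the wildcard, look up, then
            # substitute the index back into the mapped name.
            layer_num = re.search(r"layers\.(\d+)", loaded_key).group(1)
            wildcard = re.sub(r"layers\.\d+", "layers.*", loaded_key)
            mapped = _KEYS.get(wildcard, loaded_key)
            return re.sub(r"layers\.\*", f"layers.{layer_num}", mapped)
        return _KEYS.get(loaded_key, loaded_key)

    # -> "layers.7.attn.kernel_q_proj_DNH"
    print(map_name("language_model.model.layers.7.self_attn.q_proj.weight"))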
tpu_inference/models/jax/qwen2.py

@@ -368,7 +368,8 @@ class Qwen2ForCausalLM(nnx.Module):
             "lm_head": "model.lm_head",
         })
 
-        metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
+        metadata_map = get_default_maps(self.vllm_config.model_config,
+                                        self.mesh, mappings)
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
tpu_inference/models/jax/qwen2_5_vl.py

@@ -1061,7 +1061,8 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
             "lm_head": "language_model.model.lm_head",
         })
 
-        metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
+        metadata_map = get_default_maps(self.vllm_config.model_config,
+                                        self.mesh, mappings)
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
tpu_inference/models/jax/qwen3.py

@@ -295,7 +295,8 @@ class Qwen3ForCausalLM(nnx.Module):
             "lm_head": "model.lm_head",
         })
 
-        metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
+        metadata_map = get_default_maps(self.vllm_config.model_config,
+                                        self.mesh, mappings)
         load_hf_weights(vllm_config=self.vllm_config,
                         model=self,
                         metadata_map=metadata_map,
tpu_inference/models/jax/utils/weight_utils.py

@@ -18,7 +18,7 @@ from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
 from safetensors import safe_open
 
-from tpu_inference import utils
+from tpu_inference import envs, utils
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils import file_utils
 
@@ -197,12 +197,11 @@ def shard_put(x: jax.Array, shardings, mesh: jax.sharding.Mesh) -> jax.Array:
     return jax.device_put(x, shardings)
 
 
-def get_default_maps(vllm_config, mesh: Mesh,
+def get_default_maps(model_config, mesh: Mesh,
                      name_map: dict[str, str]) -> MetadataMap:
     """Load weights from one model weights file to the model, run on single thread."""
     sharding_size = mesh.shape["model"]
 
-    model_config = vllm_config.model_config
     hf_config = model_config.hf_config
 
     num_heads = hf_config.num_attention_heads
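get_default_maps now takes a model config directly instead of a whole VllmConfig, so each caller decides which config supplies the dimensions. The net effect at a call site, pieced together from the hunks in this diff:

    # Before: dims always came from vllm_config.model_config.
    metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)

    # After: the caller passes the config explicitly, which lets a draft
    # model supply its own config (see the eagle3 hunk above).
    metadata_map = get_default_maps(self.vllm_config.model_config,
                                    self.mesh, mappings)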
@@ -273,7 +272,8 @@ def _load_hf_weights_on_thread(vllm_config,
                                weights_file: str,
                                filter_regex: str | None = None,
                                keep_original_dtype_keys_regex: list[str]
-                               | None = None):
+                               | None = None,
+                               exclude_regex: list[str] | None = None):
     name_map = metadata_map.name_map
     reshape_keys = metadata_map.reshape_map
     bias_reshape_keys = metadata_map.bias_reshape_map
@@ -298,6 +298,18 @@ def _load_hf_weights_on_thread(vllm_config,
     for hf_key, hf_weight in model_weights_single_file_generator(
             weights_file, framework="flax", filter_regex=filter_regex):
 
+        # Check if the key should be excluded
+        if exclude_regex:
+            should_exclude = False
+            for pattern in exclude_regex:
+                if re.search(pattern, hf_key):
+                    logger.info(
+                        f"Excluding {hf_key} based on pattern {pattern}")
+                    should_exclude = True
+                    break
+            if should_exclude:
+                continue
+
         # Check if the key should retain its original dtype
         keep_original_dtype = False
         if keep_original_dtype_keys_regex:
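The new exclusion check is a first-match scan with re.search, so patterns match anywhere in the key. The same semantics, extracted into a self-contained sketch for illustration:

    import re

    def is_excluded(key: str, exclude_regex: list[str] | None) -> bool:
        # Mirrors the loop above: any pattern found anywhere in the key
        # causes the weight to be skipped.
        return bool(exclude_regex) and any(
            re.search(p, key) for p in exclude_regex)

    assert is_excluded("model.embed_tokens.weight", [r".*embed_tokens.*"])
    assert not is_excluded("model.lm_head.weight", [r".*embed_tokens.*"])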
@@ -408,7 +420,8 @@ def load_hf_weights(vllm_config,
                    mesh: Mesh,
                    filter_regex: str | None = None,
                    is_draft_model: bool = False,
-                   keep_original_dtype_keys_regex: list[str] | None = None):
+                   keep_original_dtype_keys_regex: list[str] | None = None,
+                   exclude_regex: list[str] | None = None):
     """Load weights from all model weights files to the model, run in multi threads."""
     if is_draft_model:
         model_path = vllm_config.speculative_config.draft_model_config.model
@@ -421,7 +434,7 @@ def load_hf_weights(vllm_config,
     # NOTE(xiang): Disable multi-threading mode if running on multi-host.
     # Because multi-threading would cause different JAX processes to load
     # different weights at the same time.
-    if os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
+    if envs.TPU_MULTIHOST_BACKEND == "ray":
         max_workers = 1
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
         futures = [
@@ -433,8 +446,8 @@ def load_hf_weights(vllm_config,
                 mesh,
                 weights_file,
                 filter_regex=filter_regex,
-                keep_original_dtype_keys_regex=keep_original_dtype_keys_regex)
-            for weights_file in weights_files
+                keep_original_dtype_keys_regex=keep_original_dtype_keys_regex,
+                exclude_regex=exclude_regex) for weights_file in weights_files
         ]
         for future in futures:
             future.result()
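Under the Ray multihost backend, max_workers drops to 1: the executor API is kept but loading degenerates to serial, so every JAX process walks the shards in the same order. A toy illustration of that degenerate mode (shard names are hypothetical):

    from concurrent.futures import ThreadPoolExecutor

    weights_files = ["shard-00001.safetensors", "shard-00002.safetensors"]

    # max_workers=1 keeps the executor structure but forces one file at a
    # time, avoiding hosts loading different shards concurrently.
    with ThreadPoolExecutor(max_workers=1) as executor:
        futures = [executor.submit(print, f) for f in weights_files]
        for future in futures:
            future.result()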
tpu_inference/models/vllm/vllm_model_wrapper.py

@@ -120,8 +120,7 @@ class VllmModelWrapper:
 
         # Load the vLLM model and wrap it into a new model whose forward
         # function can calculate the hidden_state and logits.
-        available_devices = self.mesh.devices.flatten()
-        with load_context, jax.default_device(available_devices[0]):
+        with load_context, jax.default_device(jax.devices("cpu")[0]):
             vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
         lora_manager = None
         if vllm_config_for_load.lora_config is not None:
@@ -162,6 +161,7 @@ class VllmModelWrapper:
         input_ids: jax.Array,
         attn_metadata: AttentionMetadata,
         input_embeds: jax.Array,
+        input_positions: jax.Array,
         layer_name_to_kvcache_index: Sequence[Tuple[str, int]],
         lora_metadata,
         intermediate_tensors: JaxIntermediateTensors = None,
@@ -188,8 +188,8 @@ class VllmModelWrapper:
             torch_view(params_and_buffers),
             kwargs={
                 "input_ids": torch_view(input_ids),
-                "positions": torch_view(attn_metadata.input_positions),
-                "intermediate_tensors": intermediate_tensors,
+                "positions": torch_view(input_positions),
+                "intermediate_tensors": None,
                 "inputs_embeds": None,
             },
             tie_weights=False,
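positions is now threaded through as an explicit argument rather than read from attn_metadata, and intermediate_tensors is pinned to None in the functional call. A hedged sketch of the underlying torch.func.functional_call pattern on a stand-in module (torch_view and the wrapper internals are not reproduced):

    import torch
    from torch import nn

    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = nn.Embedding(16, 8)

        def forward(self, input_ids, positions):
            # A real model would feed positions into attention; here we
            # just mix them in so the argument is exercised.
            return self.emb(input_ids) + positions[..., None]

    model = TinyModel()
    params = dict(model.named_parameters()) | dict(model.named_buffers())
    out = torch.func.functional_call(
        model, params, args=(),
        kwargs={
            "input_ids": torch.tensor([1, 2, 3]),
            "positions": torch.tensor([0, 1, 2]),  # passed explicitly
        })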
tpu_inference/platforms/tpu_platform.py

@@ -1,6 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
 
-import os
 from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
 
 import jax.numpy as jnp
@@ -183,7 +182,7 @@ class TpuPlatform(Platform):
             parallel_config.worker_cls = \
                 "tpu_inference.worker.tpu_worker.TPUWorker"
 
-        multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+        multihost_backend = envs.TPU_MULTIHOST_BACKEND
         if not multihost_backend:  # Single host
             if parallel_config.pipeline_parallel_size == 1:
                 logger.info("Force using UniProcExecutor for JAX on \
@@ -267,3 +266,7 @@ class TpuPlatform(Platform):
         Returns if the current platform needs to sync weight loader.
         """
         return True
+
+    @classmethod
+    def support_hybrid_kv_cache(cls) -> bool:
+        return True
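support_hybrid_kv_cache is a capability hook: core code can ask the active platform whether hybrid (per-layer) KV-cache layouts may be enabled. A sketch of how such a gate is typically consumed; the caller here is illustrative, not the actual vLLM call site:

    class Platform:
        @classmethod
        def support_hybrid_kv_cache(cls) -> bool:
            return False  # conservative default

    class TpuPlatform(Platform):
        @classmethod
        def support_hybrid_kv_cache(cls) -> bool:
            return True

    def maybe_enable_hybrid_cache(platform: type[Platform]) -> bool:
        # Hypothetical caller: fall back to a single uniform KV cache
        # whenever the platform does not advertise hybrid support.
        return platform.support_hybrid_kv_cache()

    assert maybe_enable_hybrid_cache(TpuPlatform)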
tpu_inference/runner/compilation_manager.py

@@ -1,6 +1,6 @@
 import os
 import time
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
 
 import jax
 import jax.numpy as jnp
@@ -135,12 +135,6 @@ class CompilationManager:
             ShardingAxisName.ATTN_DATA, )) if dp_size > 1 else None
 
         # Keep existing pattern for complex array operations
-        block_tables = self.runner.block_table_cpu[:self.runner.max_num_reqs]
-        block_tables = block_tables.reshape(-1)
-        block_tables = device_array(self.runner.mesh,
-                                    block_tables,
-                                    sharding=dp_sharding)
-
         seq_lens = self._create_dummy_tensor((self.runner.max_num_reqs, ),
                                              jnp.int32, dp_sharding)
         query_start_loc = self._create_dummy_tensor(
@@ -152,26 +146,45 @@ class CompilationManager:
             request_distribution,
             sharding=dp_sharding)
 
-        attention_metadata = AttentionMetadata(
-            input_positions=positions,
-            block_tables=block_tables,
-            seq_lens=seq_lens,
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution,
-        )
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.runner.kv_cache_config.kv_cache_groups):
+            block_tables = self.runner.block_tables_cpu[
+                kv_cache_gid][:self.runner.max_num_reqs]
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(self.runner.mesh,
+                                        block_tables,
+                                        sharding=dp_sharding)
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution,
+            )
+            if not self.runner.use_hybrid_kvcache:
+                # all layers share the same attention metadata
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid
 
         def model_fn_wrapper(
             state,
             kv_caches,
             input_ids,
             attention_metadata,
+            positions,
             inputs_embeds,
             layer_name_to_kvcache_index,
             lora_metadata,
         ):
             kv_caches, hidden_states, _ = self.runner.model_fn(
                 state, kv_caches, input_ids, attention_metadata, inputs_embeds,
-                layer_name_to_kvcache_index, lora_metadata)
+                positions, layer_name_to_kvcache_index, lora_metadata)
             self.runner.kv_caches = kv_caches
             return hidden_states
 
@@ -179,6 +192,10 @@ class CompilationManager:
                 self.runner.lora_config, np.array([num_tokens],
                                                   dtype=np.int32)):
             lora_metadata = self.runner.lora_utils.extract_lora_metadata()
+            if self.runner.use_hybrid_kvcache:
+                attention_metadata = attention_metadata_per_layer
+            else:
+                attention_metadata = uniform_attention_metadata
             self._run_compilation(
                 name,
                 model_fn_wrapper,
@@ -186,6 +203,7 @@ class CompilationManager:
                 self.runner.kv_caches,
                 input_ids,
                 attention_metadata,
+                positions,
                 inputs_embeds,
                 tuple(self.runner.layer_name_to_kvcache_index.items()),
                 lora_metadata,
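With hybrid KV caches enabled, attention metadata becomes a dict keyed by layer name, one entry per KV-cache group; otherwise a single AttentionMetadata is shared by all layers. A small sketch of that grouping logic, with hypothetical group and layer names standing in for the runner's kv_cache_groups:

    from typing import Dict

    kv_cache_groups = {
        "full_attention": ["layers.0", "layers.2"],
        "sliding_window": ["layers.1", "layers.3"],
    }

    use_hybrid_kvcache = True
    per_layer: Dict[str, str] = {}
    uniform = None
    for group_name, layer_names in kv_cache_groups.items():
        metadata = f"metadata({group_name})"  # stands in for AttentionMetadata
        if not use_hybrid_kvcache:
            uniform = metadata  # every layer shares one metadata object
        else:
            for layer_name in layer_names:
                per_layer[layer_name] = metadata

    attention_metadata = per_layer if use_hybrid_kvcache else uniform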