tpu-inference 0.11.1.dev202511130813__py3-none-any.whl → 0.11.1.dev202511220812__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (58)
  1. tests/lora/test_layers.py +0 -6
  2. tests/lora/utils.py +0 -8
  3. tests/test_envs.py +182 -0
  4. tests/test_utils.py +23 -14
  5. tpu_inference/__init__.py +22 -3
  6. tpu_inference/core/core_tpu.py +17 -9
  7. tpu_inference/core/disagg_utils.py +6 -8
  8. tpu_inference/distributed/tpu_connector.py +2 -3
  9. tpu_inference/distributed/utils.py +3 -2
  10. tpu_inference/envs.py +1 -1
  11. tpu_inference/executors/ray_distributed_executor.py +27 -11
  12. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +77 -54
  13. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +110 -64
  14. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +7 -0
  15. tpu_inference/layers/{jax → common}/attention_interface.py +1 -1
  16. tpu_inference/layers/common/quant_methods.py +8 -0
  17. tpu_inference/layers/jax/attention/attention.py +1 -1
  18. tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
  19. tpu_inference/layers/jax/sample/sampling.py +2 -2
  20. tpu_inference/layers/vllm/attention.py +1 -1
  21. tpu_inference/layers/vllm/quantization/__init__.py +7 -3
  22. tpu_inference/layers/vllm/quantization/awq.py +4 -3
  23. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -2
  24. tpu_inference/layers/vllm/quantization/mxfp4.py +266 -0
  25. tpu_inference/layers/vllm/quantization/unquantized.py +4 -3
  26. tpu_inference/layers/vllm/sharding.py +2 -2
  27. tpu_inference/lora/torch_punica_tpu.py +1 -2
  28. tpu_inference/models/common/model_loader.py +12 -11
  29. tpu_inference/models/jax/llama3.py +4 -3
  30. tpu_inference/models/jax/llama_eagle3.py +9 -5
  31. tpu_inference/models/jax/llama_guard_4.py +361 -0
  32. tpu_inference/models/jax/qwen2.py +3 -2
  33. tpu_inference/models/jax/qwen2_5_vl.py +4 -3
  34. tpu_inference/models/jax/qwen3.py +3 -2
  35. tpu_inference/models/jax/utils/weight_utils.py +21 -8
  36. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -10
  37. tpu_inference/platforms/tpu_platform.py +17 -7
  38. tpu_inference/runner/compilation_manager.py +37 -17
  39. tpu_inference/runner/kv_cache.py +1 -1
  40. tpu_inference/runner/kv_cache_manager.py +8 -2
  41. tpu_inference/runner/tpu_runner.py +199 -87
  42. tpu_inference/spec_decode/jax/eagle3.py +2 -1
  43. tpu_inference/tpu_info.py +4 -3
  44. tpu_inference/utils.py +7 -6
  45. tpu_inference/worker/tpu_worker.py +159 -23
  46. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/METADATA +2 -2
  47. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/RECORD +52 -54
  48. tpu_inference/mock/__init__.py +0 -0
  49. tpu_inference/mock/vllm_config_utils.py +0 -28
  50. tpu_inference/mock/vllm_envs.py +0 -1219
  51. tpu_inference/mock/vllm_logger.py +0 -212
  52. tpu_inference/mock/vllm_logging_utils.py +0 -15
  53. tpu_inference/models/jax/phi3.py +0 -376
  54. /tpu_inference/layers/{jax → common}/binary_search.py +0 -0
  55. /tpu_inference/layers/{jax → common}/sharding.py +0 -0
  56. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/WHEEL +0 -0
  57. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/licenses/LICENSE +0 -0
  58. {tpu_inference-0.11.1.dev202511130813.dist-info → tpu_inference-0.11.1.dev202511220812.dist-info}/top_level.txt +0 -0
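
Note: several shared layer modules move from tpu_inference.layers.jax to tpu_inference.layers.common in this release (entries 15, 54 and 55 above, see the hunks below). As an illustrative sketch only, not part of the diff, a downstream import site would change along these lines, assuming the public symbol names stay the same as the hunks suggest:

    # before (0.11.1.dev202511130813)
    # from tpu_inference.layers.jax.attention_interface import attention
    # from tpu_inference.layers.jax.sharding import ShardingConfigManager

    # after (0.11.1.dev202511220812)
    from tpu_inference.layers.common.attention_interface import attention
    from tpu_inference.layers.common.sharding import ShardingConfigManager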

tpu_inference/models/jax/llama_guard_4.py
@@ -0,0 +1,361 @@
+ import re
+ from typing import Any, List, Optional, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ import torch
+ from flax import nnx
+ from flax.typing import PRNGKey
+ from jax.sharding import Mesh
+ from jax.sharding import PartitionSpec as P
+ from vllm.config import VllmConfig
+
+ from tpu_inference.layers.jax.attention.attention import AttentionMetadata
+ from tpu_inference.layers.jax.attention.llama4_attention import Llama4Attention
+ from tpu_inference.layers.jax.constants import KVCacheType
+ from tpu_inference.layers.jax.layers import DenseFFW, Embedder, LMhead, RMSNorm
+ from tpu_inference.layers.jax.misc import shard_put
+ from tpu_inference.layers.jax.transformer_block import TransformerBlock
+ from tpu_inference.logger import init_logger
+ from tpu_inference.models.jax.utils.weight_utils import (
+     get_param, model_weights_generator, print_param_info, reshape_params,
+     transpose_params)
+
+ logger = init_logger(__name__)
+
+
+ class LlamaGuard4ForCausalLM(nnx.Module):
+
+     def __init__(self,
+                  vllm_config: VllmConfig,
+                  rng: PRNGKey,
+                  mesh: Mesh,
+                  force_random_weights: bool = False):
+         logger.warning(
+             "🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨\n"
+             "Llama Guard 4 (JAX) is WIP: Only the text modality is currently implemented. "
+             "Multimodal inputs will fail.\n"
+             "🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨")
+         assert mesh is not None
+
+         self.vllm_config = vllm_config
+         self.vllm_config.model_config.dtype = torch.bfloat16
+         model_config = vllm_config.model_config
+         text_config = model_config.hf_config.text_config
+
+         self.mesh = mesh
+         self.is_verbose = getattr(self.vllm_config.additional_config,
+                                   "is_verbose", False)
+
+         self.use_qk_norm = getattr(text_config, "use_qk_norm", True)
+
+         vocab_size = model_config.get_vocab_size()
+         self.hidden_size = model_config.get_hidden_size()
+
+         self.dtype: jnp.dtype = jnp.bfloat16
+
+         self.num_layers: int = getattr(text_config, "num_layers", 48)
+         hidden_act: str = getattr(text_config, "hidden_act", "silu")
+
+         rms_norm_eps = getattr(text_config, "rms_norm_eps", 1e-5)
+         self.num_attention_heads = getattr(text_config, "num_attention_heads",
+                                            40)
+         self.num_key_value_heads = getattr(text_config, "num_key_value_heads",
+                                            8)
+         self.head_dim = getattr(text_config, "head_dim", 128)
+
+         intermediate_size = getattr(text_config, "intermediate_size", 8192)
+
+         self.rope_theta_text = getattr(text_config, "rope_theta", 500000.0)
+         self.rope_scaling = getattr(text_config, "rope_scaling")
+
+         self.rng = nnx.Rngs(rng)
+
+         self.embedder = Embedder(
+             vocab_size=vocab_size,
+             hidden_size=self.hidden_size,
+             dtype=self.dtype,
+             vd_sharding=(('data', 'model'), None),
+             rngs=self.rng,
+             random_init=force_random_weights,
+         )
+
+         self.layers = []
+
+         for i in range(self.num_layers):
+             use_attention_rope = True
+
+             custom_module = DenseFFW(dtype=self.dtype,
+                                      hidden_act=hidden_act,
+                                      hidden_size=self.hidden_size,
+                                      intermediate_size=intermediate_size,
+                                      random_init=force_random_weights,
+                                      rngs=self.rng,
+                                      df_sharding=P(None, 'model'),
+                                      fd_sharding=P('model', None),
+                                      activation_ffw_td=P('data', None))
+
+             attn = Llama4Attention(
+                 hidden_size=self.hidden_size,
+                 dtype=self.dtype,
+                 num_attention_heads=self.num_attention_heads,
+                 num_key_value_heads=self.num_key_value_heads,
+                 head_dim=self.head_dim,
+                 rope_theta=self.rope_theta_text,
+                 rope_scaling={
+                     "scale_factor":
+                     self.rope_scaling["factor"],
+                     "low_freq_factor":
+                     self.rope_scaling["low_freq_factor"],
+                     "high_freq_factor":
+                     self.rope_scaling["high_freq_factor"],
+                     "original_max_position_embeddings":
+                     self.rope_scaling["original_max_position_embeddings"]
+                 },
+                 rngs=self.rng,
+                 rope_input_ordering="interleaved",
+                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
+                 kv_cache_dtype=vllm_config.cache_config.cache_dtype,
+                 temperature_tuning=True,
+                 temperature_tuning_scale=0.1,
+                 temperature_tuning_floor_scale=8192,
+                 use_qk_norm=self.use_qk_norm,
+                 attention_chunk_size=None if use_attention_rope else 8192,
+                 mesh=self.mesh,
+                 random_init=force_random_weights,
+                 activation_attention_td=('data', 'model'),
+                 activation_q_td=('data', 'model'),
+                 query_tnh=P('data', 'model', None),
+                 keyvalue_skh=P('data', 'model', None),
+                 activation_attention_out_td=('data', 'model'),
+                 attn_o_tnh=P('data', 'model', None),
+                 dnh_sharding=(None, 'model', None),
+                 dkh_sharding=(None, 'model', None),
+                 nhd_sharding=('model', None, None),
+             )
+
+             pre_attention_norm = RMSNorm(
+                 dims=self.hidden_size,
+                 random_init=force_random_weights,
+                 epsilon=rms_norm_eps,
+                 rngs=self.rng,
+                 activation_ffw_td=('data', None),
+                 with_scale=True,
+                 dtype=self.dtype,
+             )
+
+             pre_mlp_norm = RMSNorm(
+                 dims=self.hidden_size,
+                 activation_ffw_td=('data', None),
+                 epsilon=rms_norm_eps,
+                 rngs=self.rng,
+                 with_scale=True,
+                 dtype=self.dtype,
+                 random_init=force_random_weights,
+             )
+
+             block = TransformerBlock(custom_module=custom_module,
+                                      attn=attn,
+                                      pre_attention_norm=pre_attention_norm,
+                                      pre_mlp_norm=pre_mlp_norm,
+                                      use_attention_rope=use_attention_rope)
+             self.layers.append(block)
+
+         self.final_norm = RMSNorm(
+             dims=self.hidden_size,
+             activation_ffw_td=P(),
+             epsilon=rms_norm_eps,
+             rngs=self.rng,
+             with_scale=True,
+             dtype=self.dtype,
+             random_init=force_random_weights,
+         )
+
+         self.lm_head = LMhead(vocab_size=vocab_size,
+                               hidden_size=self.hidden_size,
+                               dtype=self.dtype,
+                               rngs=self.rng,
+                               vd_sharding=(('data', 'model'), None),
+                               dv_sharding=(None, ('data', 'model')),
+                               random_init=force_random_weights)
+         if self.is_verbose:
+             self._print_model_architecture()
+
+     def _print_model_architecture(self):
+
+         logger.info("### Embedding ###")
+         nnx.display(self.embedder)
+
+         logger.info("\n### Layers ###")
+         for i, layer in enumerate(self.layers):
+             logger.info(f"\n--- Layer {i} ---")
+             nnx.display(layer)
+
+         logger.info("\n### LM Head ###")
+         nnx.display(self.lm_head)
+
+     def load_weights(self, rng: jax.Array, cache_dir: Optional[str] = None):
+         self.rng = nnx.Rngs(rng)
+
+         weight_loader = LlamaGuard4WeightLoader(
+             vllm_config=self.vllm_config,
+             hidden_size=self.hidden_size,
+             attn_heads=self.num_attention_heads,
+             num_key_value_heads=self.num_key_value_heads,
+             attn_head_dim=self.head_dim)
+         weight_loader.load_weights(self)
+
+     def __call__(
+         self,
+         kv_caches: List[jax.Array],
+         input_ids: jax.Array,
+         attention_metadata: AttentionMetadata,
+         inputs_embeds: Optional[jax.Array] = None,
+         layer_metadata_tuple: Optional[Tuple] = None,
+         lora_metadata: Optional[Any] = None,
+         *args,
+     ) -> Tuple[List[KVCacheType], jax.Array]:
+         is_prefill = False
+
+         if inputs_embeds is not None:
+             x_TD = inputs_embeds
+         elif input_ids is not None:
+             x_TD = self.embedder.encode(input_ids)
+         else:
+             raise ValueError(
+                 "Cannot run forward pass: Both input_ids and inputs_embeds are None."
+             )
+
+         for (i, block) in enumerate(self.layers):
+             kv_cache = kv_caches[i]
+             new_kv_cache, x_TD = block(x_TD, is_prefill, kv_cache,
+                                        attention_metadata)
+             jax.block_until_ready(x_TD)
+             kv_caches[i] = new_kv_cache
+
+         final_activation_TD = self.final_norm(x_TD)
+
+         return kv_caches, final_activation_TD, []
+
+     def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+         logits_TV = jnp.dot(hidden_states,
+                             self.lm_head.input_embedding_table_DV.value)
+         return logits_TV
+
+     def get_input_embeddings(
+         self,
+         input_ids: jax.Array,
+         multimodal_embeddings: Optional[List[jax.Array]] = None
+     ) -> jax.Array:
+         """
+         Computes the embeddings for text input (used for input to fusion).
+         """
+         return self.embedder.encode(input_ids)
+
+
+ class LlamaGuard4WeightLoader:
+
+     def __init__(self, vllm_config: VllmConfig, hidden_size, attn_heads,
+                  num_key_value_heads, attn_head_dim):
+         self.names_and_weights_generator = model_weights_generator(
+             model_name_or_path=vllm_config.model_config.model,
+             framework="flax",
+             filter_regex="language_model",
+             download_dir=vllm_config.load_config.download_dir)
+         self.is_verbose = getattr(vllm_config.additional_config, "is_verbose",
+                                   False)
+         self._transpose_map = {
+             "q_proj": (2, 0, 1),
+             "k_proj": (2, 0, 1),
+             "v_proj": (2, 0, 1),
+             "o_proj": (1, 2, 0),
+             "lm_head": (1, 0),
+             "feed_forward.down_proj": (1, 0),
+             "feed_forward.gate_proj": (1, 0),
+             "feed_forward.up_proj": (1, 0),
+             "mlp.down_proj": (1, 0),
+             "mlp.gate_proj": (1, 0),
+             "mlp.up_proj": (1, 0),
+         }
+         self._weight_shape_map = {
+             "q_proj": (attn_heads, attn_head_dim, hidden_size),
+             "k_proj": (num_key_value_heads, attn_head_dim, hidden_size),
+             "v_proj": (num_key_value_heads, attn_head_dim, hidden_size),
+             "o_proj": (hidden_size, attn_heads, attn_head_dim),
+         }
+
+         self._loaded_to_standardized_keys = {
+             "language_model.model.embed_tokens.weight":
+             "embedder.input_embedding_table_VD",
+             "language_model.lm_head.weight":
+             "lm_head.input_embedding_table_DV",
+             "language_model.model.norm.weight":
+             "final_norm.scale",
+             "language_model.model.layers.*.input_layernorm.weight":
+             "layers.*.pre_attention_norm.scale",
+             "language_model.model.layers.*.post_attention_layernorm.weight":
+             "layers.*.pre_mlp_norm.scale",
+             "language_model.model.layers.*.self_attn.q_proj.weight":
+             "layers.*.attn.kernel_q_proj_DNH",
+             "language_model.model.layers.*.self_attn.k_proj.weight":
+             "layers.*.attn.kernel_k_proj_DKH",
+             "language_model.model.layers.*.self_attn.v_proj.weight":
+             "layers.*.attn.kernel_v_proj_DKH",
+             "language_model.model.layers.*.self_attn.o_proj.weight":
+             "layers.*.attn.kernel_o_proj_NHD",
+             "language_model.model.layers.*.feed_forward.gate_proj.weight":
+             "layers.*.custom_module.kernel_gating_DF",
+             "language_model.model.layers.*.feed_forward.up_proj.weight":
+             "layers.*.custom_module.kernel_up_proj_DF",
+             "language_model.model.layers.*.feed_forward.down_proj.weight":
+             "layers.*.custom_module.kernel_down_proj_FD",
+         }
+
+     def map_loaded_to_standardized_name(self, loaded_key: str) -> str:
+         if "layer" in loaded_key:
+             layer_num = re.search(r"layers\.(\d+)", loaded_key).group(1)
+             layer_key = re.sub(r"layers\.\d+", "layers.*", loaded_key)
+             mapped_key = self._loaded_to_standardized_keys.get(
+                 layer_key, loaded_key)
+             mapped_key = re.sub(r"layers\.\*", f"layers.{layer_num}",
+                                 mapped_key)
+         else:
+             mapped_key = self._loaded_to_standardized_keys.get(
+                 loaded_key, loaded_key)
+         return mapped_key
+
+     def load_weights(self, model_for_loading: nnx.Module):
+         model_params = nnx.state(model_for_loading)
+         with jax.default_device(jax.devices("cpu")[0]):
+             for loaded_name, loaded_weight in self.names_and_weights_generator:
+                 if loaded_name.endswith(".bias"):
+                     continue
+                 if "vision_model" in loaded_name or "multi_modal_projector" in loaded_name:
+                     continue
+
+                 mapped_name = self.map_loaded_to_standardized_name(loaded_name)
+                 model_weight = get_param(model_params, mapped_name)
+
+                 if not loaded_name.endswith(".bias"):
+                     # For other layers, continue to use the transpose_params helper.
+                     loaded_weight = reshape_params(loaded_name, loaded_weight,
+                                                    self._weight_shape_map)
+                     loaded_weight = transpose_params(loaded_name,
+                                                      loaded_weight,
+                                                      self._transpose_map)
+                 if model_weight.value.shape != loaded_weight.shape:
+                     raise ValueError(
+                         f"Loaded shape for {loaded_name}: {loaded_weight.shape} "
+                         f"does not match model shape for {mapped_name}: {model_weight.value.shape}!"
+                     )
+                 logger.debug(
+                     f"Transformed parameter {loaded_name} to {mapped_name}: {loaded_weight.shape} --> {model_weight.value.shape}"
+                 )
+
+                 model_weight.value = shard_put(loaded_weight,
+                                                model_weight.sharding,
+                                                mesh=model_for_loading.mesh)
+                 if self.is_verbose:
+                     print_param_info(model_weight, loaded_name)
+
+         nnx.update(model_for_loading, model_params)

tpu_inference/models/jax/qwen2.py
@@ -8,8 +8,8 @@ from transformers import Qwen2Config, modeling_flax_utils
  from vllm.config import VllmConfig

  from tpu_inference import utils
+ from tpu_inference.layers.common.attention_interface import attention
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
- from tpu_inference.layers.jax.attention_interface import attention
  from tpu_inference.layers.jax.rope_interface import apply_rope
  from tpu_inference.logger import init_logger
  from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
@@ -368,7 +368,8 @@ class Qwen2ForCausalLM(nnx.Module):
  "lm_head": "model.lm_head",
  })

- metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
+ metadata_map = get_default_maps(self.vllm_config.model_config,
+ self.mesh, mappings)
  load_hf_weights(vllm_config=self.vllm_config,
  model=self,
  metadata_map=metadata_map,

tpu_inference/models/jax/qwen2_5_vl.py
@@ -14,9 +14,9 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
  from vllm.config import VllmConfig

  from tpu_inference import utils as utils
- from tpu_inference.layers.common.attention_metadata import AttentionMetadata
- from tpu_inference.layers.jax.attention_interface import \
+ from tpu_inference.layers.common.attention_interface import \
  sharded_flash_attention
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
  from tpu_inference.logger import init_logger
  from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
  # from vllm.model_executor.models.interfaces import MultiModalEmbeddings
@@ -1061,7 +1061,8 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
  "lm_head": "language_model.model.lm_head",
  })

- metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
+ metadata_map = get_default_maps(self.vllm_config.model_config,
+ self.mesh, mappings)
  load_hf_weights(vllm_config=self.vllm_config,
  model=self,
  metadata_map=metadata_map,

tpu_inference/models/jax/qwen3.py
@@ -8,8 +8,8 @@ from transformers import Qwen3Config
  from vllm.config import VllmConfig

  from tpu_inference import utils
+ from tpu_inference.layers.common.attention_interface import attention
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
- from tpu_inference.layers.jax.attention_interface import attention
  from tpu_inference.layers.jax.rope_interface import apply_rope
  from tpu_inference.logger import init_logger
  from tpu_inference.models.jax.qwen2 import Qwen2DecoderLayer
@@ -295,7 +295,8 @@ class Qwen3ForCausalLM(nnx.Module):
  "lm_head": "model.lm_head",
  })

- metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
+ metadata_map = get_default_maps(self.vllm_config.model_config,
+ self.mesh, mappings)
  load_hf_weights(vllm_config=self.vllm_config,
  model=self,
  metadata_map=metadata_map,

tpu_inference/models/jax/utils/weight_utils.py
@@ -18,7 +18,7 @@ from jax.sharding import Mesh, NamedSharding
  from jax.sharding import PartitionSpec as P
  from safetensors import safe_open

- from tpu_inference import utils
+ from tpu_inference import envs, utils
  from tpu_inference.logger import init_logger
  from tpu_inference.models.jax.utils import file_utils

@@ -197,12 +197,11 @@ def shard_put(x: jax.Array, shardings, mesh: jax.sharding.Mesh) -> jax.Array:
  return jax.device_put(x, shardings)


- def get_default_maps(vllm_config, mesh: Mesh,
+ def get_default_maps(model_config, mesh: Mesh,
  name_map: dict[str, str]) -> MetadataMap:
  """Load weights from one model weights file to the model, run on single thread."""
  sharding_size = mesh.shape["model"]

- model_config = vllm_config.model_config
  hf_config = model_config.hf_config

  num_heads = hf_config.num_attention_heads
@@ -273,7 +272,8 @@ def _load_hf_weights_on_thread(vllm_config,
  weights_file: str,
  filter_regex: str | None = None,
  keep_original_dtype_keys_regex: list[str]
- | None = None):
+ | None = None,
+ exclude_regex: list[str] | None = None):
  name_map = metadata_map.name_map
  reshape_keys = metadata_map.reshape_map
  bias_reshape_keys = metadata_map.bias_reshape_map
@@ -298,6 +298,18 @@ def _load_hf_weights_on_thread(vllm_config,
  for hf_key, hf_weight in model_weights_single_file_generator(
  weights_file, framework="flax", filter_regex=filter_regex):

+ # Check if the key should be excluded
+ if exclude_regex:
+ should_exclude = False
+ for pattern in exclude_regex:
+ if re.search(pattern, hf_key):
+ logger.info(
+ f"Excluding {hf_key} based on pattern {pattern}")
+ should_exclude = True
+ break
+ if should_exclude:
+ continue
+
  # Check if the key should retain its original dtype
  keep_original_dtype = False
  if keep_original_dtype_keys_regex:
@@ -408,7 +420,8 @@ def load_hf_weights(vllm_config,
  mesh: Mesh,
  filter_regex: str | None = None,
  is_draft_model: bool = False,
- keep_original_dtype_keys_regex: list[str] | None = None):
+ keep_original_dtype_keys_regex: list[str] | None = None,
+ exclude_regex: list[str] | None = None):
  """Load weights from all model weights files to the model, run in multi threads."""
  if is_draft_model:
  model_path = vllm_config.speculative_config.draft_model_config.model
@@ -421,7 +434,7 @@ def load_hf_weights(vllm_config,
  # NOTE(xiang): Disable multi-threading mode if running on multi-host.
  # Because multi-threading would cause different JAX processes to load
  # different weights at the same time.
- if os.environ.get("TPU_MULTIHOST_BACKEND", "").lower() == "ray":
+ if envs.TPU_MULTIHOST_BACKEND == "ray":
  max_workers = 1
  with ThreadPoolExecutor(max_workers=max_workers) as executor:
  futures = [
@@ -433,8 +446,8 @@ def load_hf_weights(vllm_config,
  mesh,
  weights_file,
  filter_regex=filter_regex,
- keep_original_dtype_keys_regex=keep_original_dtype_keys_regex)
- for weights_file in weights_files
+ keep_original_dtype_keys_regex=keep_original_dtype_keys_regex,
+ exclude_regex=exclude_regex) for weights_file in weights_files
  ]
  for future in futures:
  future.result()

tpu_inference/models/vllm/vllm_model_wrapper.py
@@ -25,6 +25,8 @@ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
  from tpu_inference.layers.vllm.quantization import get_tpu_quantization_config
  from tpu_inference.layers.vllm.sharding import shard_model_to_tpu
  from tpu_inference.logger import init_logger
+ from tpu_inference.models.jax.jax_intermediate_tensor import \
+ JaxIntermediateTensors
  from tpu_inference.models.vllm.vllm_model_wrapper_context import (
  get_vllm_model_wrapper_context, set_vllm_model_wrapper_context)
  from tpu_inference.runner.lora_utils import replace_lora_metadata
@@ -89,13 +91,14 @@ class VllmModelWrapper:
  slice_config = self.vllm_config.device_config.slice
  modified_slice_config = True
  self.vllm_config.device_config.slice = None
+ self.vllm_config.compilation_config.static_forward_context.clear()
+
  vllm_config_for_load = copy.deepcopy(self.vllm_config)
  if modified_slice_config:
  self.vllm_config.device_config.slice = slice_config
  assert self.vllm_config.model_config.dtype in TORCH_DTYPE_TO_JAX, "The model_config.dtype must be a PyTorch dtype."
  vllm_config_for_load.device_config.device = "cpu"
  # Clearing the cached compilation config, otherwise vllm model init will fail
- vllm_config_for_load.compilation_config.static_forward_context.clear()

  # When expert parallelism is enabled, vLLM loads weight in sharding
  # aware manner. Since tpu-inference has its own sharding logic, this
@@ -117,7 +120,7 @@ class VllmModelWrapper:

  # Load the vLLM model and wrap it into a new model whose forward
  # function can calculate the hidden_state and logits.
- with load_context:
+ with load_context, jax.default_device(jax.devices("cpu")[0]):
  vllm_model = vllm_get_model(vllm_config=vllm_config_for_load)
  lora_manager = None
  if vllm_config_for_load.lora_config is not None:
@@ -149,7 +152,8 @@ class VllmModelWrapper:
  "xla_tpu_reduce_scatter_collective_matmul_mode":
  "post_spmd_conservative"
  },
- static_argnames=("layer_name_to_kvcache_index", ),
+ static_argnames=("layer_name_to_kvcache_index", "is_first_rank",
+ "is_last_rank"),
  )
  def step_fun(
  params_and_buffers, # This has been wrapped into torchax TorchValue
@@ -157,8 +161,12 @@ class VllmModelWrapper:
  input_ids: jax.Array,
  attn_metadata: AttentionMetadata,
  input_embeds: jax.Array,
+ input_positions: jax.Array,
  layer_name_to_kvcache_index: Sequence[Tuple[str, int]],
  lora_metadata,
+ intermediate_tensors: JaxIntermediateTensors = None,
+ is_first_rank: bool = True,
+ is_last_rank: bool = True,
  *args,
  ) -> Tuple[List[jax.Array], jax.Array]:
  layer_name_to_kvcache_index = dict(layer_name_to_kvcache_index)
@@ -173,12 +181,14 @@ class VllmModelWrapper:
  # torch_view in order to call the Torch function.
  original_lora_metadata = replace_lora_metadata(
  self.model, lora_metadata, self.vllm_config.lora_config)
- hidden_states = torch.func.functional_call(
+ if not is_first_rank:
+ intermediate_tensors = intermediate_tensors.to_torch()
+ output_from_torch = torch.func.functional_call(
  self.model,
  torch_view(params_and_buffers),
  kwargs={
  "input_ids": torch_view(input_ids),
- "positions": torch_view(attn_metadata.input_positions),
+ "positions": torch_view(input_positions),
  "intermediate_tensors": None,
  "inputs_embeds": None,
  },
@@ -188,11 +198,13 @@ class VllmModelWrapper:
  self.vllm_config.lora_config)
  vllm_model_wrapper_context = get_vllm_model_wrapper_context()
  new_kv_caches = vllm_model_wrapper_context.kv_caches
- # Wrap the hidden_states from torch land into a JaxValue for the jax
- # code to consume.
- hidden_states = jax_view(hidden_states)
-
- return new_kv_caches, hidden_states, []
+ # Wrap the output(hidden states or intermediate tensor)
+ # from torch land into a JaxValue for the jax code to consume.
+ if not is_last_rank:
+ output = JaxIntermediateTensors.from_torch(output_from_torch)
+ else:
+ output = jax_view(output_from_torch)
+ return new_kv_caches, output, []

  return step_fun


tpu_inference/platforms/tpu_platform.py
@@ -1,7 +1,6 @@
  # SPDX-License-Identifier: Apache-2.0

- import os
- from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
+ from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast

  import jax.numpy as jnp
  import vllm.envs as vllm_envs
@@ -12,7 +11,7 @@ from vllm.platforms.interface import Platform, PlatformEnum
  from vllm.sampling_params import SamplingParams, SamplingType

  from tpu_inference import envs
- from tpu_inference.layers.jax.sharding import ShardingConfigManager
+ from tpu_inference.layers.common.sharding import ShardingConfigManager
  from tpu_inference.logger import init_logger

  if TYPE_CHECKING:
@@ -57,7 +56,8 @@ class TpuPlatform(Platform):
  def get_attn_backend_cls(cls, selected_backend: "_Backend", head_size: int,
  dtype: jnp.dtype, kv_cache_dtype: Optional[str],
  block_size: int, use_v1: bool, use_mla: bool,
- has_sink: bool, use_sparse: bool) -> str:
+ has_sink: bool, use_sparse: bool,
+ attn_type: Any) -> str:
  from vllm.attention.backends.registry import _Backend
  if selected_backend != _Backend.PALLAS:
  logger.info("Cannot use %s backend on TPU.", selected_backend)
@@ -182,10 +182,16 @@ class TpuPlatform(Platform):
  parallel_config.worker_cls = \
  "tpu_inference.worker.tpu_worker.TPUWorker"

- multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
+ multihost_backend = envs.TPU_MULTIHOST_BACKEND
  if not multihost_backend: # Single host
- logger.info("Force using UniProcExecutor for JAX on single host.")
- parallel_config.distributed_executor_backend = "uni"
+ if parallel_config.pipeline_parallel_size == 1:
+ logger.info("Force using UniProcExecutor for JAX on \
+ single host without pipeline parallelism.")
+ parallel_config.distributed_executor_backend = "uni"
+ else:
+ logger.info("Force using MultiprocExecutor for JAX on \
+ single host with pipeline parallelism.")
+ parallel_config.distributed_executor_backend = "mp"
  elif multihost_backend == "ray":
  from tpu_inference.executors.ray_distributed_executor import \
  RayDistributedExecutor
@@ -260,3 +266,7 @@ class TpuPlatform(Platform):
  Returns if the current platform needs to sync weight loader.
  """
  return True
+
+ @classmethod
+ def support_hybrid_kv_cache(cls) -> bool:
+ return True