tpu-inference 0.0.1rc1__py3-none-any.whl → 0.11.1.dev202511130813__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (67)
  1. tests/kernels/fused_moe_v1_test.py +34 -303
  2. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +2 -2
  3. tests/lora/test_layers.py +6 -0
  4. tests/lora/utils.py +8 -0
  5. tests/test_utils.py +16 -24
  6. tpu_inference/__init__.py +3 -22
  7. tpu_inference/core/core_tpu.py +9 -17
  8. tpu_inference/core/disagg_utils.py +8 -6
  9. tpu_inference/distributed/tpu_connector.py +4 -3
  10. tpu_inference/distributed/utils.py +2 -3
  11. tpu_inference/envs.py +8 -61
  12. tpu_inference/executors/ray_distributed_executor.py +11 -31
  13. tpu_inference/kernels/fused_moe/v1/kernel.py +110 -641
  14. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +54 -77
  15. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +143 -287
  16. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +0 -7
  17. tpu_inference/layers/jax/attention/attention.py +1 -1
  18. tpu_inference/layers/{common → jax}/attention_interface.py +2 -8
  19. tpu_inference/layers/jax/sample/rejection_sampler.py +1 -1
  20. tpu_inference/layers/jax/sample/sampling.py +2 -2
  21. tpu_inference/layers/{common → jax}/sharding.py +5 -5
  22. tpu_inference/layers/vllm/attention.py +1 -1
  23. tpu_inference/layers/vllm/fused_moe.py +208 -170
  24. tpu_inference/layers/vllm/quantization/__init__.py +3 -7
  25. tpu_inference/layers/vllm/quantization/awq.py +3 -4
  26. tpu_inference/layers/vllm/quantization/common.py +1 -6
  27. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +2 -4
  28. tpu_inference/layers/vllm/quantization/unquantized.py +67 -62
  29. tpu_inference/layers/vllm/sharding.py +2 -2
  30. tpu_inference/lora/torch_punica_tpu.py +2 -1
  31. tpu_inference/mock/__init__.py +0 -0
  32. tpu_inference/mock/vllm_config_utils.py +28 -0
  33. tpu_inference/mock/vllm_envs.py +1219 -0
  34. tpu_inference/mock/vllm_logger.py +212 -0
  35. tpu_inference/mock/vllm_logging_utils.py +15 -0
  36. tpu_inference/models/common/model_loader.py +12 -46
  37. tpu_inference/models/jax/llama3.py +3 -4
  38. tpu_inference/models/jax/llama_eagle3.py +5 -8
  39. tpu_inference/models/jax/phi3.py +376 -0
  40. tpu_inference/models/jax/qwen2.py +2 -3
  41. tpu_inference/models/jax/qwen2_5_vl.py +50 -165
  42. tpu_inference/models/jax/qwen3.py +2 -3
  43. tpu_inference/models/jax/utils/quantization/quantization_utils.py +6 -3
  44. tpu_inference/models/jax/utils/weight_utils.py +143 -198
  45. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -32
  46. tpu_inference/platforms/tpu_platform.py +34 -47
  47. tpu_inference/runner/compilation_manager.py +60 -145
  48. tpu_inference/runner/kv_cache.py +2 -2
  49. tpu_inference/runner/kv_cache_manager.py +18 -17
  50. tpu_inference/runner/persistent_batch_manager.py +2 -40
  51. tpu_inference/runner/structured_decoding_manager.py +3 -2
  52. tpu_inference/runner/tpu_runner.py +135 -283
  53. tpu_inference/runner/utils.py +2 -2
  54. tpu_inference/spec_decode/jax/eagle3.py +21 -71
  55. tpu_inference/tpu_info.py +3 -4
  56. tpu_inference/utils.py +15 -38
  57. tpu_inference/worker/tpu_worker.py +26 -163
  58. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/METADATA +3 -4
  59. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/RECORD +63 -61
  60. tests/test_envs.py +0 -203
  61. tpu_inference/layers/common/quant_methods.py +0 -8
  62. tpu_inference/layers/vllm/quantization/mxfp4.py +0 -331
  63. tpu_inference/models/jax/llama_guard_4.py +0 -361
  64. /tpu_inference/layers/{common → jax}/binary_search.py +0 -0
  65. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/WHEEL +0 -0
  66. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/licenses/LICENSE +0 -0
  67. {tpu_inference-0.0.1rc1.dist-info → tpu_inference-0.11.1.dev202511130813.dist-info}/top_level.txt +0 -0
tpu_inference/models/jax/phi3.py
@@ -0,0 +1,376 @@
+ from typing import List, Optional, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ from flax import nnx
+ from jax.sharding import Mesh
+ from transformers import Phi3Config, modeling_flax_utils
+ from vllm.config import VllmConfig
+
+ from tpu_inference import utils
+ from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.jax.attention_interface import attention
+ from tpu_inference.layers.jax.rope_interface import apply_longrope, apply_rope
+ from tpu_inference.logger import init_logger
+ from tpu_inference.models.jax.utils.weight_utils import (MetadataMap,
+                                                           load_hf_weights)
+
+ logger = init_logger(__name__)
+
+ init_fn = nnx.initializers.uniform()
+
+
+ class Phi3MLP(nnx.Module):
+
+     def __init__(self, config: Phi3Config, dtype: jnp.dtype, rng: nnx.Rngs):
+         hidden_size = config.hidden_size
+         intermediate_size = config.intermediate_size
+         act = config.hidden_act
+
+         self.gate_up_proj = nnx.Linear(
+             hidden_size,
+             2 * intermediate_size,
+             use_bias=False,
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model")),
+             rngs=rng,
+         )
+         self.down_proj = nnx.Linear(
+             intermediate_size,
+             hidden_size,
+             use_bias=False,
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, ("model", None)),
+             rngs=rng,
+         )
+         self.act_fn = modeling_flax_utils.ACT2FN[act]
+
+     def __call__(self, x: jax.Array) -> jax.Array:
+         gate_up = self.gate_up_proj(x)
+         gate, up = jnp.split(gate_up, 2, axis=-1)
+         fuse = up * self.act_fn(gate)
+         result = self.down_proj(fuse)
+         return result
+
+
+ class Phi3Attention(nnx.Module):
+
+     def __init__(self, config: Phi3Config, dtype: jnp.dtype, rng: nnx.Rngs,
+                  mesh: Mesh, kv_cache_dtype: str):
+         self.hidden_size = config.hidden_size
+         self.num_heads = config.num_attention_heads
+         self.num_kv_heads = config.num_key_value_heads
+         self.rope_theta = config.rope_theta
+         self.rope_scaling = config.rope_scaling
+         self.original_max_position_embeddings = config.original_max_position_embeddings
+         self.max_position_embeddings = config.max_position_embeddings
+
+         self.head_dim_original = getattr(config, "head_dim",
+                                          self.hidden_size // self.num_heads)
+         self.head_dim = utils.get_padded_head_dim(self.head_dim_original)
+
+         sharding_size = mesh.shape["model"]
+         self.num_heads = utils.get_padded_num_heads(self.num_heads,
+                                                     sharding_size)
+         self.num_kv_heads = utils.get_padded_num_heads(self.num_kv_heads,
+                                                        sharding_size)
+
+         self.mesh = mesh
+
+         self.qkv_proj = nnx.Einsum(
+             "TD,DNH->TNH",
+             (self.hidden_size, self.num_heads + self.num_kv_heads * 2,
+              self.head_dim),
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, (None, "model", None)),
+             rngs=rng,
+         )
+         self.o_proj = nnx.Einsum(
+             "TNH,NHD->TD",
+             (self.num_heads, self.head_dim, self.hidden_size),
+             param_dtype=dtype,
+             kernel_init=nnx.with_partitioning(init_fn, ("model", None, None)),
+             rngs=rng,
+         )
+
+         self._q_scale = 1.0
+         self._k_scale = 1.0
+         self._v_scale = 1.0
+         self.kv_cache_quantized_dtype = None
+         if kv_cache_dtype != "auto":
+             self.kv_cache_quantized_dtype = utils.get_jax_dtype_from_str_dtype(
+                 kv_cache_dtype)
+
+     def __call__(
+         self,
+         kv_cache: Optional[jax.Array],
+         x: jax.Array,
+         attention_metadata: AttentionMetadata,
+     ) -> Tuple[jax.Array, jax.Array]:
+         md = attention_metadata
+         # qkv: (T, N + K * 2, H)
+         qkv = self.qkv_proj(x)
+         q, k, v = jnp.split(
+             qkv, [self.num_heads, self.num_heads + self.num_kv_heads], axis=1)
+         if self.rope_scaling:
+             q = apply_longrope(q, md.input_positions, self.head_dim_original,
+                                self.rope_scaling,
+                                self.original_max_position_embeddings,
+                                self.max_position_embeddings, self.rope_theta)
+             k = apply_longrope(k, md.input_positions, self.head_dim_original,
+                                self.rope_scaling,
+                                self.original_max_position_embeddings,
+                                self.max_position_embeddings, self.rope_theta)
+         else:
+             q = apply_rope(q, md.input_positions, self.head_dim_original,
+                            self.rope_theta, self.rope_scaling)
+             k = apply_rope(k, md.input_positions, self.head_dim_original,
+                            self.rope_theta, self.rope_scaling)
+         # o: (T, N, H)
+         q_scale = k_scale = v_scale = None
+         if self.kv_cache_quantized_dtype:
+             # TODO(kyuyeunk/jacobplatin): Enable w8a8 when VREG spill issue is resolved.
+             # q_scale = self._q_scale
+             k_scale = self._k_scale
+             v_scale = self._v_scale
+             k, v = utils.quantize_kv(k, v, self.kv_cache_quantized_dtype,
+                                      k_scale, v_scale)
+         new_kv_cache, outputs = attention(
+             kv_cache,
+             q,
+             k,
+             v,
+             attention_metadata,
+             self.mesh,
+             self.head_dim_original,
+             q_scale=q_scale,
+             k_scale=k_scale,
+             v_scale=v_scale,
+         )
+         # (T, D)
+         o = self.o_proj(outputs)
+         return new_kv_cache, o
+
+
+ class Phi3DecoderLayer(nnx.Module):
+
+     def __init__(self, config: Phi3Config, dtype: jnp.dtype, rng: nnx.Rngs,
+                  mesh: Mesh, kv_cache_dtype: str):
+         rms_norm_eps = config.rms_norm_eps
+         hidden_size = config.hidden_size
+
+         self.input_layernorm = nnx.RMSNorm(
+             hidden_size,
+             epsilon=rms_norm_eps,
+             param_dtype=dtype,
+             scale_init=nnx.with_partitioning(init_fn, (None, )),
+             rngs=rng,
+         )
+         self.self_attn = Phi3Attention(config=config,
+                                        dtype=dtype,
+                                        rng=rng,
+                                        mesh=mesh,
+                                        kv_cache_dtype=kv_cache_dtype)
+         self.post_attention_layernorm = nnx.RMSNorm(
+             hidden_size,
+             epsilon=rms_norm_eps,
+             param_dtype=dtype,
+             scale_init=nnx.with_partitioning(init_fn, (None, )),
+             rngs=rng,
+         )
+         self.mlp = Phi3MLP(
+             config=config,
+             dtype=dtype,
+             rng=rng,
+         )
+
+     def __call__(
+         self,
+         kv_cache: jax.Array,
+         x: jax.Array,
+         attention_metadata: AttentionMetadata,
+     ) -> Tuple[jax.Array, jax.Array]:
+         hidden_states = self.input_layernorm(x)
+         kv_cache, attn_output = self.self_attn(
+             kv_cache,
+             hidden_states,
+             attention_metadata,
+         )
+         attn_output += x
+
+         residual = attn_output
+         attn_output = self.post_attention_layernorm(attn_output)
+         outputs = self.mlp(attn_output)
+         outputs = residual + outputs
+         return kv_cache, outputs
+
+
+ class Phi3Model(nnx.Module):
+
+     def __init__(self, vllm_config: VllmConfig, rng: nnx.Rngs,
+                  mesh: Mesh) -> None:
+         model_config = vllm_config.model_config
+         hf_config = model_config.hf_config
+         vocab_size = model_config.get_vocab_size()
+         dtype = model_config.dtype
+         rms_norm_eps = hf_config.rms_norm_eps
+         hidden_size = hf_config.hidden_size
+
+         self.embed = nnx.Embed(
+             num_embeddings=vocab_size,
+             features=hidden_size,
+             param_dtype=dtype,
+             embedding_init=nnx.with_partitioning(init_fn, ("model", None)),
+             rngs=rng,
+         )
+         self.layers = [
+             Phi3DecoderLayer(
+                 config=hf_config,
+                 dtype=dtype,
+                 rng=rng,
+                 mesh=mesh,
+                 # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
+                 kv_cache_dtype=vllm_config.cache_config.cache_dtype)
+             for _ in range(hf_config.num_hidden_layers)
+         ]
+         self.norm = nnx.RMSNorm(
+             hidden_size,
+             epsilon=rms_norm_eps,
+             param_dtype=dtype,
+             scale_init=nnx.with_partitioning(init_fn, (None, )),
+             rngs=rng,
+         )
+         if model_config.hf_config.tie_word_embeddings:
+             self.lm_head = self.embed.embedding
+         else:
+             self.lm_head = nnx.Param(
+                 init_fn(rng.params(), (hidden_size, vocab_size), dtype),
+                 sharding=(None, "model"),
+             )
+
+     def __call__(
+         self,
+         kv_caches: List[jax.Array],
+         input_ids: jax.Array,
+         attention_metadata: AttentionMetadata,
+     ) -> Tuple[List[jax.Array], jax.Array]:
+         x = self.embed(input_ids)
+         for i, layer in enumerate(self.layers):
+             kv_cache = kv_caches[i]
+             kv_cache, x = layer(
+                 kv_cache,
+                 x,
+                 attention_metadata,
+             )
+             kv_caches[i] = kv_cache
+         x = self.norm(x)
+         return kv_caches, x
+
+
+ class Phi3ForCausalLM(nnx.Module):
+
+     def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array,
+                  mesh: Mesh) -> None:
+         self.vllm_config = vllm_config
+         self.rng = nnx.Rngs(rng_key)
+         self.mesh = mesh
+
+         self.model = Phi3Model(
+             vllm_config=vllm_config,
+             rng=self.rng,
+             mesh=mesh,
+         )
+
+     def __call__(
+         self,
+         kv_caches: List[jax.Array],
+         input_ids: jax.Array,
+         attention_metadata: AttentionMetadata,
+         *args,
+     ) -> Tuple[List[jax.Array], jax.Array, List[jax.Array]]:
+         kv_caches, x = self.model(
+             kv_caches,
+             input_ids,
+             attention_metadata,
+         )
+         return kv_caches, x, []
+
+     def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+         if self.vllm_config.model_config.hf_config.tie_word_embeddings:
+             logits = jnp.dot(hidden_states, self.model.lm_head.value.T)
+         else:
+             logits = jnp.dot(hidden_states, self.model.lm_head.value)
+         return logits
+
+     def get_metadata_map(self) -> MetadataMap:
+         sharding_size = self.mesh.shape["model"]
+
+         model_config = self.vllm_config.model_config
+         hf_config = model_config.hf_config
+
+         num_heads = hf_config.num_attention_heads
+         num_kv_heads = hf_config.num_key_value_heads
+         qkv_heads = num_heads + num_kv_heads * 2
+         hidden_size = model_config.get_hidden_size()
+
+         # Pad head_dim for kernel performance.
+         head_dim_original = model_config.get_head_size()
+
+         # Key: path to a HF layer weight
+         # Value: path to a nnx layer weight
+         name_map = {
+             "model.embed_tokens": "model.embed.embedding",
+             "model.layers.*.input_layernorm":
+             "model.layers.*.input_layernorm.scale",
+             "model.layers.*.mlp.down_proj":
+             "model.layers.*.mlp.down_proj.kernel",
+             "model.layers.*.mlp.gate_up_proj":
+             "model.layers.*.mlp.gate_up_proj.kernel",
+             "model.layers.*.post_attention_layernorm":
+             "model.layers.*.post_attention_layernorm.scale",
+             "model.layers.*.self_attn.qkv_proj":
+             "model.layers.*.self_attn.qkv_proj.kernel",
+             "model.layers.*.self_attn.o_proj":
+             "model.layers.*.self_attn.o_proj.kernel",
+             "model.norm": "model.norm.scale",
+         }
+         if not self.vllm_config.model_config.hf_config.tie_word_embeddings:
+             name_map.update({
+                 "lm_head": "model.lm_head",
+             })
+
+         reshape_keys: dict[str, tuple[int, ...]] = {
+             "qkv_proj": (qkv_heads, head_dim_original, hidden_size),
+             "o_proj": (hidden_size, num_heads, head_dim_original),
+         }
+         transpose_keys: dict[str, tuple[int, ...]] = {
+             "lm_head": (1, 0),
+             "gate_up_proj": (1, 0),
+             "down_proj": (1, 0),
+             "qkv_proj": (2, 0, 1),
+             "o_proj": (1, 2, 0),
+         }
+
+         # key: (padding_dim, padding_size)
+         pad_keys: dict[str, tuple[int, ...]] = {
+             "qkv_proj": (1, sharding_size // num_heads),
+             "o_proj": (0, sharding_size // num_heads),
+         }
+
+         return MetadataMap(name_map=name_map,
+                            reshape_map=reshape_keys,
+                            bias_reshape_map={},
+                            transpose_map=transpose_keys,
+                            pad_map=pad_keys,
+                            bias_pad_map={})
+
+     def load_weights(self, rng_key: jax.Array):
+         # NOTE: Since we are using nnx.eval_shape to init the model,
+         # we have to pass dynamic arrays here for __call__'s usage.
+         self.rng = nnx.Rngs(rng_key)
+
+         metadata_map = self.get_metadata_map()
+         load_hf_weights(vllm_config=self.vllm_config,
+                         model=self,
+                         metadata_map=metadata_map,
+                         mesh=self.mesh)
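
Note on the new Phi3MLP above: the gate and up projections are fused into a single gate_up_proj matmul whose output is split in half and combined as up * act(gate) before down_proj. A minimal standalone sketch of that pattern follows (plain JAX only; the function name and toy shapes are illustrative and not part of this package, and SiLU is used here only as an example activation):

import jax
import jax.numpy as jnp

def gated_mlp(x, w_gate_up, w_down, act=jax.nn.silu):
    # One matmul produces [gate | up]; split along the feature axis.
    gate, up = jnp.split(x @ w_gate_up, 2, axis=-1)
    # Elementwise gating, then project back to the hidden size.
    return (up * act(gate)) @ w_down

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (4, 8))             # (tokens, hidden)
w_gate_up = jax.random.normal(key, (8, 32))    # hidden -> 2 * intermediate
w_down = jax.random.normal(key, (16, 8))       # intermediate -> hidden
print(gated_mlp(x, w_gate_up, w_down).shape)   # (4, 8)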
tpu_inference/models/jax/qwen2.py
@@ -8,8 +8,8 @@ from transformers import Qwen2Config, modeling_flax_utils
  from vllm.config import VllmConfig

  from tpu_inference import utils
- from tpu_inference.layers.common.attention_interface import attention
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.jax.attention_interface import attention
  from tpu_inference.layers.jax.rope_interface import apply_rope
  from tpu_inference.logger import init_logger
  from tpu_inference.models.jax.utils.weight_utils import (get_default_maps,
@@ -368,8 +368,7 @@ class Qwen2ForCausalLM(nnx.Module):
              "lm_head": "model.lm_head",
          })

-         metadata_map = get_default_maps(self.vllm_config.model_config,
-                                         self.mesh, mappings)
+         metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
          load_hf_weights(vllm_config=self.vllm_config,
                          model=self,
                          metadata_map=metadata_map,
tpu_inference/models/jax/qwen2_5_vl.py
@@ -14,9 +14,9 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
  from vllm.config import VllmConfig

  from tpu_inference import utils as utils
- from tpu_inference.layers.common.attention_interface import \
-     sharded_flash_attention
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.jax.attention_interface import \
+     sharded_flash_attention
  from tpu_inference.logger import init_logger
  from tpu_inference.models.jax.qwen2 import Qwen2ForCausalLM
  # from vllm.model_executor.models.interfaces import MultiModalEmbeddings
@@ -486,11 +486,6 @@ class Qwen2_5_VisionTransformer(nnx.Module):
                                   dtype=dtype,
                                   rngs=rngs)

-         additional_config = getattr(vllm_config, "additional_config",
-                                     None) or {}
-         self.enable_dynamic_image_sizes = additional_config.get(
-             "enable_dynamic_image_sizes", False)
-
      def rotary_pos_emb_thw(self, t, h, w):
          hpos_ids, wpos_ids = jnp.indices((h, w))
          hpos_ids = hpos_ids.reshape(
@@ -584,7 +579,21 @@ class Qwen2_5_VisionTransformer(nnx.Module):
          seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
          return max_seqlen, seqlens

-     def compute_aux_arrays(self, grid_thw: tuple[tuple[int, int, int]]):
+     def __call__(self, x: jax.Array, grid_thw: tuple[tuple[int, int,
+                                                            int]]) -> jax.Array:
+         # x: pixel_values: jax.Array
+         # """Shape:
+         # `(num_patches, num_channels * patch_size * patch_size)`
+         # """
+
+         # grid_thw: image_grid_thw: jax.Array
+         # """Shape: `(num_images, 3)`
+         # This should be in `(grid_t, grid_h, grid_w)` format.
+         # """
+         hidden_states = self.patch_embed(x)
+
+         # num of patches
+         seq_len = x.shape[0]
          # num of images/videoes
          num_grids = len(grid_thw)

@@ -629,42 +638,6 @@ class Qwen2_5_VisionTransformer(nnx.Module):
          cu_seqlens = jnp.pad(cu_seqlens, ((1, 0), ),
                               mode='constant',
                               constant_values=0)
-         return window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens
-
-     def pad_inputs(self, x, window_index, rotary_pos_emb, cu_seqlens,
-                    cu_window_seqlens):
-         # padding
-         num_patches = int(rotary_pos_emb.shape[0])
-         bucket_num_patches = 1 << (num_patches - 1).bit_length()
-         num_tokens = window_index.shape[0]
-         bucket_num_tokens = bucket_num_patches // self.spatial_merge_unit
-         vit_merger_window_size = (self.window_size //
-                                   self.spatial_merge_size // self.patch_size)
-         max_windows = (bucket_num_tokens // vit_merger_window_size) + 2
-
-         rotary_pos_emb = jnp.pad(rotary_pos_emb,
-                                  ((0, bucket_num_patches - num_patches),
-                                   (0, 0)))
-         window_index = jnp.concatenate([
-             window_index,
-             jnp.arange(num_tokens, bucket_num_tokens, dtype=jnp.int32)
-         ])
-         cu_window_seqlens = jnp.append(cu_window_seqlens, bucket_num_patches)
-         pad_w = max(0, max_windows + 1 - cu_window_seqlens.shape[0])
-         cu_window_seqlens = jnp.pad(cu_window_seqlens, (0, pad_w), mode='edge')
-         cu_seqlens = jnp.append(cu_seqlens, bucket_num_patches)
-
-         x_padded = jnp.pad(x, ((0, bucket_num_patches - x.shape[0]), (0, 0)))
-
-         return x_padded, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens, num_tokens
-
-     def compute_hidden_states(self, x: jax.Array, window_index: jax.Array,
-                               rotary_pos_emb: jax.Array, cu_seqlens: jax.Array,
-                               cu_window_seqlens: jax.Array) -> jax.Array:
-         hidden_states = self.patch_embed(x)
-
-         # num of patches
-         seq_len = x.shape[0]

          hidden_states = hidden_states.reshape(
              seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
@@ -691,48 +664,6 @@ class Qwen2_5_VisionTransformer(nnx.Module):
          hidden_states = hidden_states[reverse_indices, :]
          return hidden_states

-     @jax.jit
-     def encode_padded_jit(self, x_padded, window_index, rotary_pos_emb,
-                           cu_seqlens, cu_window_seqlens):
-         return self.compute_hidden_states(x_padded, window_index,
-                                           rotary_pos_emb, cu_seqlens,
-                                           cu_window_seqlens)
-
-     @partial(
-         jax.jit,
-         static_argnames=("grid_thw", ),
-     )
-     def encode_jit(self, x, grid_thw):
-         window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens = self.compute_aux_arrays(
-             grid_thw)
-         return self.compute_hidden_states(x, window_index, rotary_pos_emb,
-                                           cu_seqlens, cu_window_seqlens)
-
-     def __call__(self, x: jax.Array, grid_thw: tuple[tuple[int, int,
-                                                            int]]) -> jax.Array:
-         # x: pixel_values: jax.Array
-         # """Shape:
-         # `(num_patches, num_channels * patch_size * patch_size)`
-         # """
-
-         # grid_thw: image_grid_thw: jax.Array
-         # """Shape: `(num_images, 3)`
-         # This should be in `(grid_t, grid_h, grid_w)` format.
-         # """
-         if self.enable_dynamic_image_sizes:
-             window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens = self.compute_aux_arrays(
-                 grid_thw)
-             x_padded, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens, num_tokens = self.pad_inputs(
-                 x, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens)
-
-             hidden_states = self.encode_padded_jit(x_padded, window_index,
-                                                    rotary_pos_emb, cu_seqlens,
-                                                    cu_window_seqlens)
-             return hidden_states[:num_tokens]
-
-         else:
-             return self.encode_jit(x, grid_thw)
-

  class Qwen2_5_VLForConditionalGeneration(nnx.Module):

@@ -957,6 +888,10 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
          # "video"] = self._parse_and_validate_video_input(**kwargs)
          return mm_input_by_modality

+     @partial(
+         jax.jit,
+         static_argnames=("image_grid_thw", ),
+     )
      def get_single_image_embedding(self, image_pixel_values, image_grid_thw):
          return self.visual(image_pixel_values, (image_grid_thw, ))

@@ -1126,8 +1061,7 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
              "lm_head": "language_model.model.lm_head",
          })

-         metadata_map = get_default_maps(self.vllm_config.model_config,
-                                         self.mesh, mappings)
+         metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
          load_hf_weights(vllm_config=self.vllm_config,
                          model=self,
                          metadata_map=metadata_map,
@@ -1137,82 +1071,33 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
          self,
          run_compilation_fn: Callable,
      ) -> None:
+         image_shapes = []
+         if (warmup_config := self.vllm_config.additional_config.get(
+                 "vision_warmup_config")):
+             image_shapes = warmup_config.get("image_shapes")
+
          vc = self.vllm_config.model_config.hf_config.vision_config
-         patch_input_dim = vc.in_channels * vc.temporal_patch_size * vc.patch_size * vc.patch_size
-         if self.visual.enable_dynamic_image_sizes:
-             spatial_merge_unit = vc.spatial_merge_size**2
-             max_num_batched_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
-             mm_kwargs = self.vllm_config.model_config.multimodal_config.mm_processor_kwargs or {}
-             limit_pixels = float(mm_kwargs.get("max_pixels", float('inf')))
-
-             max_patches = int(
-                 min(max_num_batched_tokens * spatial_merge_unit,
-                     limit_pixels / (vc.patch_size**2)))
-
-             num_patches_paddings = [
-                 1 << i for i in range(4, (max_patches - 1).bit_length() + 1)
-             ]
-             rotary_dim = vc.hidden_size // vc.num_heads // 2
-             vit_merger_window_size = (vc.window_size //
-                                       vc.spatial_merge_size // vc.patch_size)
-
-             for num_patches in num_patches_paddings:
-                 dummy_x_padded = jnp.ones(
-                     (num_patches, patch_input_dim),
-                     dtype=self.vllm_config.model_config.dtype)
-
-                 num_tokens = num_patches // spatial_merge_unit
-                 dummy_window_index = jnp.arange(num_tokens, dtype=jnp.int32)
-
-                 dummy_rotary_pos_emb = jnp.ones(
-                     (num_patches, rotary_dim),
-                     dtype=self.vllm_config.model_config.dtype)
-
-                 dummy_cu_seqlens = jnp.array([0, num_patches, num_patches],
-                                              dtype=jnp.int32)
-
-                 max_windows = (num_tokens // vit_merger_window_size) + 2
-                 patches_per_window = (vit_merger_window_size**
-                                       2) * spatial_merge_unit
-                 dummy_cu_window_seqlens = jnp.arange(
-                     max_windows + 1, dtype=jnp.int32) * patches_per_window
-                 dummy_cu_window_seqlens = jnp.minimum(dummy_cu_window_seqlens,
-                                                       num_patches)
-
-                 run_compilation_fn("vision_encoder_padded",
-                                    self.visual.encode_padded_jit,
-                                    dummy_x_padded,
-                                    dummy_window_index,
-                                    dummy_rotary_pos_emb,
-                                    dummy_cu_seqlens,
-                                    dummy_cu_window_seqlens,
-                                    num_patches=num_patches)
-         else:
-             image_shapes = []
-             if (warmup_config := self.vllm_config.additional_config.get(
-                     "vision_warmup_config")):
-                 image_shapes = warmup_config.get("image_shapes")
-
-             factor = vc.patch_size * vc.spatial_merge_size
-             for input_hw in image_shapes:
-                 if not isinstance(input_hw, list) or len(input_hw) != 2:
-                     logger.warning(f"Skipping invalid shape {input_hw}.")
-                     continue
-                 h_input, w_input = input_hw
-                 h_processed = round(h_input / factor) * factor
-                 w_processed = round(w_input / factor) * factor
-                 t, h, w = 1, h_processed // vc.patch_size, w_processed // vc.patch_size
-                 grid_thw = (t, h, w)
-                 num_patches = t * h * w
-
-                 dummy_pixel_values = jnp.ones(
-                     (num_patches, patch_input_dim),
-                     self.vllm_config.model_config.dtype,
-                 )
-                 dummy_grid_thw = (grid_thw, )
+         factor = vc.patch_size * vc.spatial_merge_size
+         for input_hw in image_shapes:
+             if not isinstance(input_hw, list) or len(input_hw) != 2:
+                 logger.warning(f"Skipping invalid shape {input_hw}.")
+                 continue
+             h_input, w_input = input_hw
+             h_processed = round(h_input / factor) * factor
+             w_processed = round(w_input / factor) * factor
+             t, h, w = 1, h_processed // vc.patch_size, w_processed // vc.patch_size
+             grid_thw = (t, h, w)
+             num_patches = t * h * w
+             patch_input_dim = vc.in_channels * vc.temporal_patch_size * vc.patch_size * vc.patch_size
+
+             dummy_pixel_values = jnp.ones(
+                 (num_patches, patch_input_dim),
+                 self.vllm_config.model_config.dtype,
+             )
+             dummy_grid_thw = grid_thw

-             run_compilation_fn("vision_encoder",
-                                self.visual.encode_jit,
-                                dummy_pixel_values,
-                                dummy_grid_thw,
-                                image_shape=input_hw)
+             run_compilation_fn("single_image_encoder",
+                                self.get_single_image_embedding,
+                                dummy_pixel_values,
+                                dummy_grid_thw,
+                                image_shape=input_hw)
tpu_inference/models/jax/qwen3.py
@@ -8,8 +8,8 @@ from transformers import Qwen3Config
  from vllm.config import VllmConfig

  from tpu_inference import utils
- from tpu_inference.layers.common.attention_interface import attention
  from tpu_inference.layers.common.attention_metadata import AttentionMetadata
+ from tpu_inference.layers.jax.attention_interface import attention
  from tpu_inference.layers.jax.rope_interface import apply_rope
  from tpu_inference.logger import init_logger
  from tpu_inference.models.jax.qwen2 import Qwen2DecoderLayer
@@ -295,8 +295,7 @@ class Qwen3ForCausalLM(nnx.Module):
              "lm_head": "model.lm_head",
          })

-         metadata_map = get_default_maps(self.vllm_config.model_config,
-                                         self.mesh, mappings)
+         metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
          load_hf_weights(vllm_config=self.vllm_config,
                          model=self,
                          metadata_map=metadata_map,