tpu-inference 0.11.1.dev202511180814__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (76)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/kernels/mla_v1_test.py +129 -41
  3. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  4. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  5. tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  6. tests/lora/test_layers.py +4 -7
  7. tests/lora/test_lora_perf.py +53 -0
  8. tests/lora/utils.py +0 -8
  9. tests/test_envs.py +110 -12
  10. tests/test_quantization.py +3 -0
  11. tests/test_utils.py +1 -2
  12. tpu_inference/__init__.py +22 -3
  13. tpu_inference/core/disagg_utils.py +6 -8
  14. tpu_inference/distributed/tpu_connector.py +3 -4
  15. tpu_inference/distributed/utils.py +3 -2
  16. tpu_inference/envs.py +93 -9
  17. tpu_inference/executors/ray_distributed_executor.py +9 -2
  18. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  19. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  20. tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
  21. tpu_inference/kernels/mla/v1/kernel.py +98 -120
  22. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  23. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  24. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  25. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +140 -67
  26. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +204 -120
  27. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
  28. tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  29. tpu_inference/layers/common/attention_interface.py +7 -1
  30. tpu_inference/layers/common/sharding.py +11 -7
  31. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
  32. tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  33. tpu_inference/layers/vllm/fused_moe.py +170 -208
  34. tpu_inference/layers/vllm/linear_common.py +43 -21
  35. tpu_inference/layers/vllm/quantization/common.py +11 -6
  36. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  37. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
  38. tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
  39. tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
  40. tpu_inference/layers/vllm/sharding.py +2 -2
  41. tpu_inference/lora/torch_punica_tpu.py +1 -2
  42. tpu_inference/models/common/model_loader.py +84 -28
  43. tpu_inference/models/jax/deepseek_v3.py +185 -64
  44. tpu_inference/models/jax/gpt_oss.py +3 -3
  45. tpu_inference/models/jax/llama3.py +2 -1
  46. tpu_inference/models/jax/llama_eagle3.py +8 -5
  47. tpu_inference/models/jax/llama_guard_4.py +361 -0
  48. tpu_inference/models/jax/qwen2.py +2 -1
  49. tpu_inference/models/jax/qwen2_5_vl.py +163 -48
  50. tpu_inference/models/jax/qwen3.py +2 -1
  51. tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
  52. tpu_inference/models/jax/utils/weight_utils.py +205 -144
  53. tpu_inference/models/vllm/vllm_model_wrapper.py +14 -8
  54. tpu_inference/platforms/tpu_platform.py +34 -50
  55. tpu_inference/runner/compilation_manager.py +144 -60
  56. tpu_inference/runner/kv_cache.py +40 -20
  57. tpu_inference/runner/kv_cache_manager.py +48 -33
  58. tpu_inference/runner/persistent_batch_manager.py +40 -2
  59. tpu_inference/runner/structured_decoding_manager.py +2 -3
  60. tpu_inference/runner/tpu_runner.py +280 -149
  61. tpu_inference/runner/utils.py +2 -2
  62. tpu_inference/spec_decode/jax/eagle3.py +71 -21
  63. tpu_inference/tpu_info.py +4 -3
  64. tpu_inference/utils.py +46 -18
  65. tpu_inference/worker/tpu_worker.py +197 -63
  66. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +9 -10
  67. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +70 -74
  68. tpu_inference/mock/__init__.py +0 -0
  69. tpu_inference/mock/vllm_config_utils.py +0 -28
  70. tpu_inference/mock/vllm_envs.py +0 -1219
  71. tpu_inference/mock/vllm_logger.py +0 -212
  72. tpu_inference/mock/vllm_logging_utils.py +0 -15
  73. tpu_inference/models/jax/phi3.py +0 -376
  74. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
  75. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
  76. {tpu_inference-0.11.1.dev202511180814.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0

tpu_inference/models/jax/llama_guard_4.py (new file)
@@ -0,0 +1,361 @@
+ import re
+ from typing import Any, List, Optional, Tuple
+
+ import jax
+ import jax.numpy as jnp
+ import torch
+ from flax import nnx
+ from flax.typing import PRNGKey
+ from jax.sharding import Mesh
+ from jax.sharding import PartitionSpec as P
+ from vllm.config import VllmConfig
+
+ from tpu_inference.layers.jax.attention.attention import AttentionMetadata
+ from tpu_inference.layers.jax.attention.llama4_attention import Llama4Attention
+ from tpu_inference.layers.jax.constants import KVCacheType
+ from tpu_inference.layers.jax.layers import DenseFFW, Embedder, LMhead, RMSNorm
+ from tpu_inference.layers.jax.misc import shard_put
+ from tpu_inference.layers.jax.transformer_block import TransformerBlock
+ from tpu_inference.logger import init_logger
+ from tpu_inference.models.jax.utils.weight_utils import (
+ get_param, model_weights_generator, print_param_info, reshape_params,
+ transpose_params)
+
+ logger = init_logger(__name__)
+
+
+ class LlamaGuard4ForCausalLM(nnx.Module):
+
+ def __init__(self,
+ vllm_config: VllmConfig,
+ rng: PRNGKey,
+ mesh: Mesh,
+ force_random_weights: bool = False):
+ logger.warning(
+ "🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨\n"
+ "Llama Guard 4 (JAX) is WIP: Only the text modality is currently implemented. "
+ "Multimodal inputs will fail.\n"
+ "🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨 🚨🚨🚨WARNING🚨🚨🚨")
+ assert mesh is not None
+
+ self.vllm_config = vllm_config
+ self.vllm_config.model_config.dtype = torch.bfloat16
+ model_config = vllm_config.model_config
+ text_config = model_config.hf_config.text_config
+
+ self.mesh = mesh
+ self.is_verbose = getattr(self.vllm_config.additional_config,
+ "is_verbose", False)
+
+ self.use_qk_norm = getattr(text_config, "use_qk_norm", True)
+
+ vocab_size = model_config.get_vocab_size()
+ self.hidden_size = model_config.get_hidden_size()
+
+ self.dtype: jnp.dtype = jnp.bfloat16
+
+ self.num_layers: int = getattr(text_config, "num_layers", 48)
+ hidden_act: str = getattr(text_config, "hidden_act", "silu")
+
+ rms_norm_eps = getattr(text_config, "rms_norm_eps", 1e-5)
+ self.num_attention_heads = getattr(text_config, "num_attention_heads",
+ 40)
+ self.num_key_value_heads = getattr(text_config, "num_key_value_heads",
+ 8)
+ self.head_dim = getattr(text_config, "head_dim", 128)
+
+ intermediate_size = getattr(text_config, "intermediate_size", 8192)
+
+ self.rope_theta_text = getattr(text_config, "rope_theta", 500000.0)
+ self.rope_scaling = getattr(text_config, "rope_scaling")
+
+ self.rng = nnx.Rngs(rng)
+
+ self.embedder = Embedder(
+ vocab_size=vocab_size,
+ hidden_size=self.hidden_size,
+ dtype=self.dtype,
+ vd_sharding=(('data', 'model'), None),
+ rngs=self.rng,
+ random_init=force_random_weights,
+ )
+
+ self.layers = []
+
+ for i in range(self.num_layers):
+ use_attention_rope = True
+
+ custom_module = DenseFFW(dtype=self.dtype,
+ hidden_act=hidden_act,
+ hidden_size=self.hidden_size,
+ intermediate_size=intermediate_size,
+ random_init=force_random_weights,
+ rngs=self.rng,
+ df_sharding=P(None, 'model'),
+ fd_sharding=P('model', None),
+ activation_ffw_td=P('data', None))
+
+ attn = Llama4Attention(
+ hidden_size=self.hidden_size,
+ dtype=self.dtype,
+ num_attention_heads=self.num_attention_heads,
+ num_key_value_heads=self.num_key_value_heads,
+ head_dim=self.head_dim,
+ rope_theta=self.rope_theta_text,
+ rope_scaling={
+ "scale_factor":
+ self.rope_scaling["factor"],
+ "low_freq_factor":
+ self.rope_scaling["low_freq_factor"],
+ "high_freq_factor":
+ self.rope_scaling["high_freq_factor"],
+ "original_max_position_embeddings":
+ self.rope_scaling["original_max_position_embeddings"]
+ },
+ rngs=self.rng,
+ rope_input_ordering="interleaved",
+ # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
+ kv_cache_dtype=vllm_config.cache_config.cache_dtype,
+ temperature_tuning=True,
+ temperature_tuning_scale=0.1,
+ temperature_tuning_floor_scale=8192,
+ use_qk_norm=self.use_qk_norm,
+ attention_chunk_size=None if use_attention_rope else 8192,
+ mesh=self.mesh,
+ random_init=force_random_weights,
+ activation_attention_td=('data', 'model'),
+ activation_q_td=('data', 'model'),
+ query_tnh=P('data', 'model', None),
+ keyvalue_skh=P('data', 'model', None),
+ activation_attention_out_td=('data', 'model'),
+ attn_o_tnh=P('data', 'model', None),
+ dnh_sharding=(None, 'model', None),
+ dkh_sharding=(None, 'model', None),
+ nhd_sharding=('model', None, None),
+ )
+
+ pre_attention_norm = RMSNorm(
+ dims=self.hidden_size,
+ random_init=force_random_weights,
+ epsilon=rms_norm_eps,
+ rngs=self.rng,
+ activation_ffw_td=('data', None),
+ with_scale=True,
+ dtype=self.dtype,
+ )
+
+ pre_mlp_norm = RMSNorm(
+ dims=self.hidden_size,
+ activation_ffw_td=('data', None),
+ epsilon=rms_norm_eps,
+ rngs=self.rng,
+ with_scale=True,
+ dtype=self.dtype,
+ random_init=force_random_weights,
+ )
+
+ block = TransformerBlock(custom_module=custom_module,
+ attn=attn,
+ pre_attention_norm=pre_attention_norm,
+ pre_mlp_norm=pre_mlp_norm,
+ use_attention_rope=use_attention_rope)
+ self.layers.append(block)
+
+ self.final_norm = RMSNorm(
+ dims=self.hidden_size,
+ activation_ffw_td=P(),
+ epsilon=rms_norm_eps,
+ rngs=self.rng,
+ with_scale=True,
+ dtype=self.dtype,
+ random_init=force_random_weights,
+ )
+
+ self.lm_head = LMhead(vocab_size=vocab_size,
+ hidden_size=self.hidden_size,
+ dtype=self.dtype,
+ rngs=self.rng,
+ vd_sharding=(('data', 'model'), None),
+ dv_sharding=(None, ('data', 'model')),
+ random_init=force_random_weights)
+ if self.is_verbose:
+ self._print_model_architecture()
+
+ def _print_model_architecture(self):
+
+ logger.info("### Embedding ###")
+ nnx.display(self.embedder)
+
+ logger.info("\n### Layers ###")
+ for i, layer in enumerate(self.layers):
+ logger.info(f"\n--- Layer {i} ---")
+ nnx.display(layer)
+
+ logger.info("\n### LM Head ###")
+ nnx.display(self.lm_head)
+
+ def load_weights(self, rng: jax.Array, cache_dir: Optional[str] = None):
+ self.rng = nnx.Rngs(rng)
+
+ weight_loader = LlamaGuard4WeightLoader(
+ vllm_config=self.vllm_config,
+ hidden_size=self.hidden_size,
+ attn_heads=self.num_attention_heads,
+ num_key_value_heads=self.num_key_value_heads,
+ attn_head_dim=self.head_dim)
+ weight_loader.load_weights(self)
+
+ def __call__(
+ self,
+ kv_caches: List[jax.Array],
+ input_ids: jax.Array,
+ attention_metadata: AttentionMetadata,
+ inputs_embeds: Optional[jax.Array] = None,
+ layer_metadata_tuple: Optional[Tuple] = None,
+ lora_metadata: Optional[Any] = None,
+ *args,
+ ) -> Tuple[List[KVCacheType], jax.Array]:
+ is_prefill = False
+
+ if inputs_embeds is not None:
+ x_TD = inputs_embeds
+ elif input_ids is not None:
+ x_TD = self.embedder.encode(input_ids)
+ else:
+ raise ValueError(
+ "Cannot run forward pass: Both input_ids and inputs_embeds are None."
+ )
+
+ for (i, block) in enumerate(self.layers):
+ kv_cache = kv_caches[i]
+ new_kv_cache, x_TD = block(x_TD, is_prefill, kv_cache,
+ attention_metadata)
+ jax.block_until_ready(x_TD)
+ kv_caches[i] = new_kv_cache
+
+ final_activation_TD = self.final_norm(x_TD)
+
+ return kv_caches, final_activation_TD, []
+
+ def compute_logits(self, hidden_states: jax.Array) -> jax.Array:
+ logits_TV = jnp.dot(hidden_states,
+ self.lm_head.input_embedding_table_DV.value)
+ return logits_TV
+
+ def get_input_embeddings(
+ self,
+ input_ids: jax.Array,
+ multimodal_embeddings: Optional[List[jax.Array]] = None
+ ) -> jax.Array:
+ """
+ Computes the embeddings for text input (used for input to fusion).
+ """
+ return self.embedder.encode(input_ids)
+
+
+ class LlamaGuard4WeightLoader:
+
+ def __init__(self, vllm_config: VllmConfig, hidden_size, attn_heads,
+ num_key_value_heads, attn_head_dim):
+ self.names_and_weights_generator = model_weights_generator(
+ model_name_or_path=vllm_config.model_config.model,
+ framework="flax",
+ filter_regex="language_model",
+ download_dir=vllm_config.load_config.download_dir)
+ self.is_verbose = getattr(vllm_config.additional_config, "is_verbose",
+ False)
+ self._transpose_map = {
+ "q_proj": (2, 0, 1),
+ "k_proj": (2, 0, 1),
+ "v_proj": (2, 0, 1),
+ "o_proj": (1, 2, 0),
+ "lm_head": (1, 0),
+ "feed_forward.down_proj": (1, 0),
+ "feed_forward.gate_proj": (1, 0),
+ "feed_forward.up_proj": (1, 0),
+ "mlp.down_proj": (1, 0),
+ "mlp.gate_proj": (1, 0),
+ "mlp.up_proj": (1, 0),
+ }
+ self._weight_shape_map = {
+ "q_proj": (attn_heads, attn_head_dim, hidden_size),
+ "k_proj": (num_key_value_heads, attn_head_dim, hidden_size),
+ "v_proj": (num_key_value_heads, attn_head_dim, hidden_size),
+ "o_proj": (hidden_size, attn_heads, attn_head_dim),
+ }
+
+ self._loaded_to_standardized_keys = {
+ "language_model.model.embed_tokens.weight":
+ "embedder.input_embedding_table_VD",
+ "language_model.lm_head.weight":
+ "lm_head.input_embedding_table_DV",
+ "language_model.model.norm.weight":
+ "final_norm.scale",
+ "language_model.model.layers.*.input_layernorm.weight":
+ "layers.*.pre_attention_norm.scale",
+ "language_model.model.layers.*.post_attention_layernorm.weight":
+ "layers.*.pre_mlp_norm.scale",
+ "language_model.model.layers.*.self_attn.q_proj.weight":
+ "layers.*.attn.kernel_q_proj_DNH",
+ "language_model.model.layers.*.self_attn.k_proj.weight":
+ "layers.*.attn.kernel_k_proj_DKH",
+ "language_model.model.layers.*.self_attn.v_proj.weight":
+ "layers.*.attn.kernel_v_proj_DKH",
+ "language_model.model.layers.*.self_attn.o_proj.weight":
+ "layers.*.attn.kernel_o_proj_NHD",
+ "language_model.model.layers.*.feed_forward.gate_proj.weight":
+ "layers.*.custom_module.kernel_gating_DF",
+ "language_model.model.layers.*.feed_forward.up_proj.weight":
+ "layers.*.custom_module.kernel_up_proj_DF",
+ "language_model.model.layers.*.feed_forward.down_proj.weight":
+ "layers.*.custom_module.kernel_down_proj_FD",
+ }
+
+ def map_loaded_to_standardized_name(self, loaded_key: str) -> str:
+ if "layer" in loaded_key:
+ layer_num = re.search(r"layers\.(\d+)", loaded_key).group(1)
+ layer_key = re.sub(r"layers\.\d+", "layers.*", loaded_key)
+ mapped_key = self._loaded_to_standardized_keys.get(
+ layer_key, loaded_key)
+ mapped_key = re.sub(r"layers\.\*", f"layers.{layer_num}",
+ mapped_key)
+ else:
+ mapped_key = self._loaded_to_standardized_keys.get(
+ loaded_key, loaded_key)
+ return mapped_key
+
+ def load_weights(self, model_for_loading: nnx.Module):
+ model_params = nnx.state(model_for_loading)
+ with jax.default_device(jax.devices("cpu")[0]):
+ for loaded_name, loaded_weight in self.names_and_weights_generator:
+ if loaded_name.endswith(".bias"):
+ continue
+ if "vision_model" in loaded_name or "multi_modal_projector" in loaded_name:
+ continue
+
+ mapped_name = self.map_loaded_to_standardized_name(loaded_name)
+ model_weight = get_param(model_params, mapped_name)
+
+ if not loaded_name.endswith(".bias"):
+ # For other layers, continue to use the transpose_params helper.
+ loaded_weight = reshape_params(loaded_name, loaded_weight,
+ self._weight_shape_map)
+ loaded_weight = transpose_params(loaded_name,
+ loaded_weight,
+ self._transpose_map)
+ if model_weight.value.shape != loaded_weight.shape:
+ raise ValueError(
+ f"Loaded shape for {loaded_name}: {loaded_weight.shape} "
+ f"does not match model shape for {mapped_name}: {model_weight.value.shape}!"
+ )
+ logger.debug(
+ f"Transformed parameter {loaded_name} to {mapped_name}: {loaded_weight.shape} --> {model_weight.value.shape}"
+ )
+
+ model_weight.value = shard_put(loaded_weight,
+ model_weight.sharding,
+ mesh=model_for_loading.mesh)
+ if self.is_verbose:
+ print_param_info(model_weight, loaded_name)
+
+ nnx.update(model_for_loading, model_params)
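
The heart of LlamaGuard4WeightLoader is the wildcard mapping from Hugging Face checkpoint names to the Flax parameter tree: the layer index is captured once, the lookup uses a layers.* template, and the index is substituted back into the result. A minimal, self-contained sketch of that substitution (the two-entry table and the map_name helper are illustrative stand-ins, not the full loader):

import re

_loaded_to_standardized = {
    "language_model.model.layers.*.self_attn.q_proj.weight":
    "layers.*.attn.kernel_q_proj_DNH",
    "language_model.model.norm.weight": "final_norm.scale",
}

def map_name(loaded_key: str) -> str:
    # Capture the layer index, look up the wildcard template, then restore the index.
    if "layers." in loaded_key:
        layer_num = re.search(r"layers\.(\d+)", loaded_key).group(1)
        wildcard_key = re.sub(r"layers\.\d+", "layers.*", loaded_key)
        mapped = _loaded_to_standardized.get(wildcard_key, loaded_key)
        return re.sub(r"layers\.\*", f"layers.{layer_num}", mapped)
    return _loaded_to_standardized.get(loaded_key, loaded_key)

assert map_name("language_model.model.layers.7.self_attn.q_proj.weight") == \
    "layers.7.attn.kernel_q_proj_DNH"

Names that are not in the table fall through unchanged; vision_model and multi_modal_projector tensors are skipped explicitly before the lookup, since only the text modality is implemented.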

tpu_inference/models/jax/qwen2.py
@@ -368,7 +368,8 @@ class Qwen2ForCausalLM(nnx.Module):
  "lm_head": "model.lm_head",
  })
 
- metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
+ metadata_map = get_default_maps(self.vllm_config.model_config,
+ self.mesh, mappings)
  load_hf_weights(vllm_config=self.vllm_config,
  model=self,
  metadata_map=metadata_map,

tpu_inference/models/jax/qwen2_5_vl.py
@@ -486,6 +486,11 @@ class Qwen2_5_VisionTransformer(nnx.Module):
  dtype=dtype,
  rngs=rngs)
 
+ additional_config = getattr(vllm_config, "additional_config",
+ None) or {}
+ self.enable_dynamic_image_sizes = additional_config.get(
+ "enable_dynamic_image_sizes", False)
+
  def rotary_pos_emb_thw(self, t, h, w):
  hpos_ids, wpos_ids = jnp.indices((h, w))
  hpos_ids = hpos_ids.reshape(
@@ -579,21 +584,7 @@ class Qwen2_5_VisionTransformer(nnx.Module):
  seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
  return max_seqlen, seqlens
 
- def __call__(self, x: jax.Array, grid_thw: tuple[tuple[int, int,
- int]]) -> jax.Array:
- # x: pixel_values: jax.Array
- # """Shape:
- # `(num_patches, num_channels * patch_size * patch_size)`
- # """
-
- # grid_thw: image_grid_thw: jax.Array
- # """Shape: `(num_images, 3)`
- # This should be in `(grid_t, grid_h, grid_w)` format.
- # """
- hidden_states = self.patch_embed(x)
-
- # num of patches
- seq_len = x.shape[0]
+ def compute_aux_arrays(self, grid_thw: tuple[tuple[int, int, int]]):
  # num of images/videoes
  num_grids = len(grid_thw)
 
@@ -638,6 +629,42 @@ class Qwen2_5_VisionTransformer(nnx.Module):
  cu_seqlens = jnp.pad(cu_seqlens, ((1, 0), ),
  mode='constant',
  constant_values=0)
+ return window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens
+
+ def pad_inputs(self, x, window_index, rotary_pos_emb, cu_seqlens,
+ cu_window_seqlens):
+ # padding
+ num_patches = int(rotary_pos_emb.shape[0])
+ bucket_num_patches = 1 << (num_patches - 1).bit_length()
+ num_tokens = window_index.shape[0]
+ bucket_num_tokens = bucket_num_patches // self.spatial_merge_unit
+ vit_merger_window_size = (self.window_size //
+ self.spatial_merge_size // self.patch_size)
+ max_windows = (bucket_num_tokens // vit_merger_window_size) + 2
+
+ rotary_pos_emb = jnp.pad(rotary_pos_emb,
+ ((0, bucket_num_patches - num_patches),
+ (0, 0)))
+ window_index = jnp.concatenate([
+ window_index,
+ jnp.arange(num_tokens, bucket_num_tokens, dtype=jnp.int32)
+ ])
+ cu_window_seqlens = jnp.append(cu_window_seqlens, bucket_num_patches)
+ pad_w = max(0, max_windows + 1 - cu_window_seqlens.shape[0])
+ cu_window_seqlens = jnp.pad(cu_window_seqlens, (0, pad_w), mode='edge')
+ cu_seqlens = jnp.append(cu_seqlens, bucket_num_patches)
+
+ x_padded = jnp.pad(x, ((0, bucket_num_patches - x.shape[0]), (0, 0)))
+
+ return x_padded, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens, num_tokens
+
+ def compute_hidden_states(self, x: jax.Array, window_index: jax.Array,
+ rotary_pos_emb: jax.Array, cu_seqlens: jax.Array,
+ cu_window_seqlens: jax.Array) -> jax.Array:
+ hidden_states = self.patch_embed(x)
+
+ # num of patches
+ seq_len = x.shape[0]
 
  hidden_states = hidden_states.reshape(
  seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
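
pad_inputs above rounds the patch count up to the next power of two, so the JIT-compiled encoder only ever sees a small, fixed set of input shapes. A standalone sketch of the bucketing arithmetic (the spatial_merge_unit value is assumed for illustration; in the model it is spatial_merge_size ** 2):

spatial_merge_unit = 4  # assumed value for illustration

def bucket_patches(num_patches: int) -> int:
    # Round up to the next power of two, mirroring pad_inputs above.
    return 1 << (num_patches - 1).bit_length()

for n in (100, 256, 300):
    b = bucket_patches(n)
    print(n, "->", b, "patches,", b // spatial_merge_unit, "merged tokens")
# 100 -> 128 patches, 32 merged tokens
# 256 -> 256 patches, 64 merged tokens
# 300 -> 512 patches, 128 merged tokens

The rotary embeddings, window indices, and cumulative-sequence-length arrays are padded to the same bucket, so every argument of encode_padded_jit has a bucketed shape.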
@@ -664,6 +691,48 @@ class Qwen2_5_VisionTransformer(nnx.Module):
  hidden_states = hidden_states[reverse_indices, :]
  return hidden_states
 
+ @jax.jit
+ def encode_padded_jit(self, x_padded, window_index, rotary_pos_emb,
+ cu_seqlens, cu_window_seqlens):
+ return self.compute_hidden_states(x_padded, window_index,
+ rotary_pos_emb, cu_seqlens,
+ cu_window_seqlens)
+
+ @partial(
+ jax.jit,
+ static_argnames=("grid_thw", ),
+ )
+ def encode_jit(self, x, grid_thw):
+ window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens = self.compute_aux_arrays(
+ grid_thw)
+ return self.compute_hidden_states(x, window_index, rotary_pos_emb,
+ cu_seqlens, cu_window_seqlens)
+
+ def __call__(self, x: jax.Array, grid_thw: tuple[tuple[int, int,
+ int]]) -> jax.Array:
+ # x: pixel_values: jax.Array
+ # """Shape:
+ # `(num_patches, num_channels * patch_size * patch_size)`
+ # """
+
+ # grid_thw: image_grid_thw: jax.Array
+ # """Shape: `(num_images, 3)`
+ # This should be in `(grid_t, grid_h, grid_w)` format.
+ # """
+ if self.enable_dynamic_image_sizes:
+ window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens = self.compute_aux_arrays(
+ grid_thw)
+ x_padded, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens, num_tokens = self.pad_inputs(
+ x, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens)
+
+ hidden_states = self.encode_padded_jit(x_padded, window_index,
+ rotary_pos_emb, cu_seqlens,
+ cu_window_seqlens)
+ return hidden_states[:num_tokens]
+
+ else:
+ return self.encode_jit(x, grid_thw)
+
 
  class Qwen2_5_VLForConditionalGeneration(nnx.Module):
 
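
The split into encode_jit and encode_padded_jit reflects a compilation trade-off: encode_jit marks grid_thw as static, so every new image grid triggers a fresh trace and compile, while the padded path keeps all arguments dynamic but bucketed and reuses at most one compilation per power-of-two bucket. A toy demonstration of the retracing behaviour (independent of the model code; toy_encode is made up for illustration):

from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=("grid_thw", ))
def toy_encode(x, grid_thw):
    # grid_thw is a static Python value: each distinct value compiles separately.
    t, h, w = grid_thw[0]
    return x * (t * h * w)

x = jnp.ones((8, 4))
toy_encode(x, ((1, 2, 4), ))   # first grid: traces and compiles
toy_encode(x, ((1, 4, 4), ))   # different grid: traces and compiles again

With enable_dynamic_image_sizes set, arbitrary image sizes pay a padding overhead instead of a recompile per unique grid.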
@@ -888,10 +957,6 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
  # "video"] = self._parse_and_validate_video_input(**kwargs)
  return mm_input_by_modality
 
- @partial(
- jax.jit,
- static_argnames=("image_grid_thw", ),
- )
  def get_single_image_embedding(self, image_pixel_values, image_grid_thw):
  return self.visual(image_pixel_values, (image_grid_thw, ))
 
@@ -1061,7 +1126,8 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
  "lm_head": "language_model.model.lm_head",
  })
 
- metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
+ metadata_map = get_default_maps(self.vllm_config.model_config,
+ self.mesh, mappings)
  load_hf_weights(vllm_config=self.vllm_config,
  model=self,
  metadata_map=metadata_map,
@@ -1071,33 +1137,82 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
  self,
  run_compilation_fn: Callable,
  ) -> None:
- image_shapes = []
- if (warmup_config := self.vllm_config.additional_config.get(
- "vision_warmup_config")):
- image_shapes = warmup_config.get("image_shapes")
-
  vc = self.vllm_config.model_config.hf_config.vision_config
- factor = vc.patch_size * vc.spatial_merge_size
- for input_hw in image_shapes:
- if not isinstance(input_hw, list) or len(input_hw) != 2:
- logger.warning(f"Skipping invalid shape {input_hw}.")
- continue
- h_input, w_input = input_hw
- h_processed = round(h_input / factor) * factor
- w_processed = round(w_input / factor) * factor
- t, h, w = 1, h_processed // vc.patch_size, w_processed // vc.patch_size
- grid_thw = (t, h, w)
- num_patches = t * h * w
- patch_input_dim = vc.in_channels * vc.temporal_patch_size * vc.patch_size * vc.patch_size
-
- dummy_pixel_values = jnp.ones(
- (num_patches, patch_input_dim),
- self.vllm_config.model_config.dtype,
- )
- dummy_grid_thw = grid_thw
+ patch_input_dim = vc.in_channels * vc.temporal_patch_size * vc.patch_size * vc.patch_size
+ if self.visual.enable_dynamic_image_sizes:
+ spatial_merge_unit = vc.spatial_merge_size**2
+ max_num_batched_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
+ mm_kwargs = self.vllm_config.model_config.multimodal_config.mm_processor_kwargs or {}
+ limit_pixels = float(mm_kwargs.get("max_pixels", float('inf')))
+
+ max_patches = int(
+ min(max_num_batched_tokens * spatial_merge_unit,
+ limit_pixels / (vc.patch_size**2)))
+
+ num_patches_paddings = [
+ 1 << i for i in range(4, (max_patches - 1).bit_length() + 1)
+ ]
+ rotary_dim = vc.hidden_size // vc.num_heads // 2
+ vit_merger_window_size = (vc.window_size //
+ vc.spatial_merge_size // vc.patch_size)
+
+ for num_patches in num_patches_paddings:
+ dummy_x_padded = jnp.ones(
+ (num_patches, patch_input_dim),
+ dtype=self.vllm_config.model_config.dtype)
+
+ num_tokens = num_patches // spatial_merge_unit
+ dummy_window_index = jnp.arange(num_tokens, dtype=jnp.int32)
+
+ dummy_rotary_pos_emb = jnp.ones(
+ (num_patches, rotary_dim),
+ dtype=self.vllm_config.model_config.dtype)
+
+ dummy_cu_seqlens = jnp.array([0, num_patches, num_patches],
+ dtype=jnp.int32)
+
+ max_windows = (num_tokens // vit_merger_window_size) + 2
+ patches_per_window = (vit_merger_window_size**
+ 2) * spatial_merge_unit
+ dummy_cu_window_seqlens = jnp.arange(
+ max_windows + 1, dtype=jnp.int32) * patches_per_window
+ dummy_cu_window_seqlens = jnp.minimum(dummy_cu_window_seqlens,
+ num_patches)
+
+ run_compilation_fn("vision_encoder_padded",
+ self.visual.encode_padded_jit,
+ dummy_x_padded,
+ dummy_window_index,
+ dummy_rotary_pos_emb,
+ dummy_cu_seqlens,
+ dummy_cu_window_seqlens,
+ num_patches=num_patches)
+ else:
+ image_shapes = []
+ if (warmup_config := self.vllm_config.additional_config.get(
+ "vision_warmup_config")):
+ image_shapes = warmup_config.get("image_shapes")
+
+ factor = vc.patch_size * vc.spatial_merge_size
+ for input_hw in image_shapes:
+ if not isinstance(input_hw, list) or len(input_hw) != 2:
+ logger.warning(f"Skipping invalid shape {input_hw}.")
+ continue
+ h_input, w_input = input_hw
+ h_processed = round(h_input / factor) * factor
+ w_processed = round(w_input / factor) * factor
+ t, h, w = 1, h_processed // vc.patch_size, w_processed // vc.patch_size
+ grid_thw = (t, h, w)
+ num_patches = t * h * w
+
+ dummy_pixel_values = jnp.ones(
+ (num_patches, patch_input_dim),
+ self.vllm_config.model_config.dtype,
+ )
+ dummy_grid_thw = (grid_thw, )
 
- run_compilation_fn("single_image_encoder",
- self.get_single_image_embedding,
- dummy_pixel_values,
- dummy_grid_thw,
- image_shape=input_hw)
+ run_compilation_fn("vision_encoder",
+ self.visual.encode_jit,
+ dummy_pixel_values,
+ dummy_grid_thw,
+ image_shape=input_hw)
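
The warmup loop above enumerates power-of-two buckets up to the configured maximum, matching the rounding that pad_inputs applies at runtime, so the padded encoder shapes are compiled ahead of time. A short illustration of the bucket list with an assumed patch budget:

max_patches = 5000  # assumed; in the code it is derived from the scheduler and max_pixels limits
buckets = [1 << i for i in range(4, (max_patches - 1).bit_length() + 1)]
print(buckets)  # [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]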

tpu_inference/models/jax/qwen3.py
@@ -295,7 +295,8 @@ class Qwen3ForCausalLM(nnx.Module):
  "lm_head": "model.lm_head",
  })
 
- metadata_map = get_default_maps(self.vllm_config, self.mesh, mappings)
+ metadata_map = get_default_maps(self.vllm_config.model_config,
+ self.mesh, mappings)
  load_hf_weights(vllm_config=self.vllm_config,
  model=self,
  metadata_map=metadata_map,

tpu_inference/models/jax/utils/quantization/quantization_utils.py
@@ -154,12 +154,9 @@ def qwix_quantize_nnx_model(model: nnx.Module, qwix_config: List[dict],
  logger.info(f"Memory usage before applying quantization of params: "
  f"hbm={utils.hbm_usage_gb(jax.local_devices())}Gb")
 
- # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
- kv_cache_jnp_dtype = utils.get_jax_dtype_from_str_dtype(kv_cache_dtype)
-
- # Handle the case where kv_cache_dtype is "auto"
- if kv_cache_jnp_dtype is None:
- assert kv_cache_dtype == "auto", "kv_cache_dtype must be 'auto' if kv_cache_jnp_dtype is None"
+ if kv_cache_dtype != "auto":
+ kv_cache_jnp_dtype = utils.to_jax_dtype(kv_cache_dtype)
+ else:
  kv_cache_jnp_dtype = DEFAULT_KV_CACHE_DTYPE
 
  kv_caches = create_kv_caches(
@@ -169,9 +166,11 @@ def qwix_quantize_nnx_model(model: nnx.Module, qwix_config: List[dict],
  head_size=kv_cache_head_size,
  mesh=mesh,
  layer_names=[f"layer.{i}" for i in range(num_hidden_layers)],
- cache_dtype=kv_cache_jnp_dtype)
+ cache_dtype=kv_cache_jnp_dtype,
+ use_mla=model.vllm_config.model_config.use_mla,
+ )
 
- dp_size = mesh.shape.get("data", 1) * mesh.shape.get("attn", 1)
+ dp_size = model.vllm_config.sharding_config.total_dp_size
 
  # NOTE: the inputs don't need to match the actual ones, as long as the consumed weights are the same
  input_ids = jax.random.randint(rng,