tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (59)
  1. tests/kernels/fused_moe_v1_test.py +303 -34
  2. tests/kernels/mla_v1_test.py +129 -41
  3. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  4. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
  5. tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
  6. tests/lora/test_layers.py +4 -1
  7. tests/lora/test_lora_perf.py +53 -0
  8. tests/test_envs.py +110 -12
  9. tests/test_quantization.py +3 -0
  10. tests/test_utils.py +1 -2
  11. tpu_inference/distributed/tpu_connector.py +1 -1
  12. tpu_inference/envs.py +92 -8
  13. tpu_inference/executors/ray_distributed_executor.py +5 -1
  14. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  15. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  16. tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
  17. tpu_inference/kernels/mla/v1/kernel.py +98 -120
  18. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  19. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  20. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  21. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +82 -32
  22. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +146 -85
  23. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
  24. tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
  25. tpu_inference/layers/common/attention_interface.py +7 -1
  26. tpu_inference/layers/common/sharding.py +11 -7
  27. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
  28. tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
  29. tpu_inference/layers/vllm/fused_moe.py +170 -208
  30. tpu_inference/layers/vllm/linear_common.py +43 -21
  31. tpu_inference/layers/vllm/quantization/common.py +11 -6
  32. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
  33. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
  34. tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
  35. tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
  36. tpu_inference/models/common/model_loader.py +78 -22
  37. tpu_inference/models/jax/deepseek_v3.py +185 -64
  38. tpu_inference/models/jax/gpt_oss.py +3 -3
  39. tpu_inference/models/jax/llama_eagle3.py +4 -5
  40. tpu_inference/models/jax/qwen2_5_vl.py +161 -47
  41. tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
  42. tpu_inference/models/jax/utils/weight_utils.py +203 -155
  43. tpu_inference/models/vllm/vllm_model_wrapper.py +11 -5
  44. tpu_inference/platforms/tpu_platform.py +29 -48
  45. tpu_inference/runner/compilation_manager.py +112 -46
  46. tpu_inference/runner/kv_cache.py +40 -20
  47. tpu_inference/runner/kv_cache_manager.py +40 -31
  48. tpu_inference/runner/persistent_batch_manager.py +40 -2
  49. tpu_inference/runner/structured_decoding_manager.py +2 -3
  50. tpu_inference/runner/tpu_runner.py +94 -51
  51. tpu_inference/runner/utils.py +2 -2
  52. tpu_inference/spec_decode/jax/eagle3.py +71 -22
  53. tpu_inference/utils.py +41 -14
  54. tpu_inference/worker/tpu_worker.py +43 -45
  55. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +8 -9
  56. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +59 -58
  57. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
  58. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
  59. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
tpu_inference/models/jax/qwen2_5_vl.py

@@ -486,6 +486,11 @@ class Qwen2_5_VisionTransformer(nnx.Module):
             dtype=dtype,
             rngs=rngs)
 
+        additional_config = getattr(vllm_config, "additional_config",
+                                    None) or {}
+        self.enable_dynamic_image_sizes = additional_config.get(
+            "enable_dynamic_image_sizes", False)
+
     def rotary_pos_emb_thw(self, t, h, w):
         hpos_ids, wpos_ids = jnp.indices((h, w))
         hpos_ids = hpos_ids.reshape(
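The new flag is read from vLLM's additional_config dict and defaults to False. A minimal sketch of switching it on from user code, assuming vLLM's standard LLM/EngineArgs plumbing for additional_config (only the key name comes from this diff):

# Hedged sketch: enabling the dynamic-image-size path. LLM accepting
# additional_config is an assumption about the vLLM API surface; the
# "enable_dynamic_image_sizes" key is the one read in the hunk above.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-VL-7B-Instruct",  # hypothetical model choice
    additional_config={"enable_dynamic_image_sizes": True},
)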
@@ -579,21 +584,7 @@ class Qwen2_5_VisionTransformer(nnx.Module):
         seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         return max_seqlen, seqlens
 
-    def __call__(self, x: jax.Array, grid_thw: tuple[tuple[int, int,
-                                                           int]]) -> jax.Array:
-        # x: pixel_values: jax.Array
-        # """Shape:
-        # `(num_patches, num_channels * patch_size * patch_size)`
-        # """
-
-        # grid_thw: image_grid_thw: jax.Array
-        # """Shape: `(num_images, 3)`
-        # This should be in `(grid_t, grid_h, grid_w)` format.
-        # """
-        hidden_states = self.patch_embed(x)
-
-        # num of patches
-        seq_len = x.shape[0]
+    def compute_aux_arrays(self, grid_thw: tuple[tuple[int, int, int]]):
         # num of images/videoes
         num_grids = len(grid_thw)
 
@@ -638,6 +629,42 @@ class Qwen2_5_VisionTransformer(nnx.Module):
         cu_seqlens = jnp.pad(cu_seqlens, ((1, 0), ),
                              mode='constant',
                              constant_values=0)
+        return window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens
+
+    def pad_inputs(self, x, window_index, rotary_pos_emb, cu_seqlens,
+                   cu_window_seqlens):
+        # padding
+        num_patches = int(rotary_pos_emb.shape[0])
+        bucket_num_patches = 1 << (num_patches - 1).bit_length()
+        num_tokens = window_index.shape[0]
+        bucket_num_tokens = bucket_num_patches // self.spatial_merge_unit
+        vit_merger_window_size = (self.window_size //
+                                  self.spatial_merge_size // self.patch_size)
+        max_windows = (bucket_num_tokens // vit_merger_window_size) + 2
+
+        rotary_pos_emb = jnp.pad(rotary_pos_emb,
+                                 ((0, bucket_num_patches - num_patches),
+                                  (0, 0)))
+        window_index = jnp.concatenate([
+            window_index,
+            jnp.arange(num_tokens, bucket_num_tokens, dtype=jnp.int32)
+        ])
+        cu_window_seqlens = jnp.append(cu_window_seqlens, bucket_num_patches)
+        pad_w = max(0, max_windows + 1 - cu_window_seqlens.shape[0])
+        cu_window_seqlens = jnp.pad(cu_window_seqlens, (0, pad_w), mode='edge')
+        cu_seqlens = jnp.append(cu_seqlens, bucket_num_patches)
+
+        x_padded = jnp.pad(x, ((0, bucket_num_patches - x.shape[0]), (0, 0)))
+
+        return x_padded, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens, num_tokens
+
+    def compute_hidden_states(self, x: jax.Array, window_index: jax.Array,
+                              rotary_pos_emb: jax.Array, cu_seqlens: jax.Array,
+                              cu_window_seqlens: jax.Array) -> jax.Array:
+        hidden_states = self.patch_embed(x)
+
+        # num of patches
+        seq_len = x.shape[0]
 
         hidden_states = hidden_states.reshape(
             seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
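pad_inputs rounds the patch count up to the next power of two before any jitted work, so XLA sees at most O(log N) distinct shapes rather than one compilation per image size. A standalone sketch of the bucketing rule used above:

# Power-of-two bucketing, as in pad_inputs above (pure-Python sketch).
def bucket_size(num_patches: int) -> int:
    # (n - 1).bit_length() is ceil(log2(n)) for n >= 1, so this rounds
    # n up to the next power of two; exact powers map to themselves.
    return 1 << (num_patches - 1).bit_length()

assert bucket_size(1000) == 1024
assert bucket_size(1024) == 1024
assert bucket_size(1025) == 2048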
@@ -664,6 +691,48 @@ class Qwen2_5_VisionTransformer(nnx.Module):
         hidden_states = hidden_states[reverse_indices, :]
         return hidden_states
 
+    @jax.jit
+    def encode_padded_jit(self, x_padded, window_index, rotary_pos_emb,
+                          cu_seqlens, cu_window_seqlens):
+        return self.compute_hidden_states(x_padded, window_index,
+                                          rotary_pos_emb, cu_seqlens,
+                                          cu_window_seqlens)
+
+    @partial(
+        jax.jit,
+        static_argnames=("grid_thw", ),
+    )
+    def encode_jit(self, x, grid_thw):
+        window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens = self.compute_aux_arrays(
+            grid_thw)
+        return self.compute_hidden_states(x, window_index, rotary_pos_emb,
+                                          cu_seqlens, cu_window_seqlens)
+
+    def __call__(self, x: jax.Array, grid_thw: tuple[tuple[int, int,
+                                                           int]]) -> jax.Array:
+        # x: pixel_values: jax.Array
+        # """Shape:
+        # `(num_patches, num_channels * patch_size * patch_size)`
+        # """
+
+        # grid_thw: image_grid_thw: jax.Array
+        # """Shape: `(num_images, 3)`
+        # This should be in `(grid_t, grid_h, grid_w)` format.
+        # """
+        if self.enable_dynamic_image_sizes:
+            window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens = self.compute_aux_arrays(
+                grid_thw)
+            x_padded, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens, num_tokens = self.pad_inputs(
+                x, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens)
+
+            hidden_states = self.encode_padded_jit(x_padded, window_index,
+                                                   rotary_pos_emb, cu_seqlens,
+                                                   cu_window_seqlens)
+            return hidden_states[:num_tokens]
+
+        else:
+            return self.encode_jit(x, grid_thw)
+
 
 class Qwen2_5_VLForConditionalGeneration(nnx.Module):
 
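The two entry points trade compilation count against padding overhead: encode_jit marks grid_thw static, so every new grid shape triggers a fresh XLA compilation, while encode_padded_jit only ever sees bucketed shapes. A toy illustration of that difference, independent of the model code:

# Toy contrast (not from the diff) between static-argument jitting and
# fixed padded shapes.
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=("n", ))
def head_static(x, n):
    return x[:n]  # recompiled for every distinct value of n

@jax.jit
def head_padded(x, mask):
    return jnp.where(mask, x, 0)  # one compilation per padded x.shape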
@@ -888,10 +957,6 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
         #     "video"] = self._parse_and_validate_video_input(**kwargs)
         return mm_input_by_modality
 
-    @partial(
-        jax.jit,
-        static_argnames=("image_grid_thw", ),
-    )
     def get_single_image_embedding(self, image_pixel_values, image_grid_thw):
         return self.visual(image_pixel_values, (image_grid_thw, ))
 
@@ -1072,33 +1137,82 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
         self,
         run_compilation_fn: Callable,
     ) -> None:
-        image_shapes = []
-        if (warmup_config := self.vllm_config.additional_config.get(
-                "vision_warmup_config")):
-            image_shapes = warmup_config.get("image_shapes")
-
         vc = self.vllm_config.model_config.hf_config.vision_config
-        factor = vc.patch_size * vc.spatial_merge_size
-        for input_hw in image_shapes:
-            if not isinstance(input_hw, list) or len(input_hw) != 2:
-                logger.warning(f"Skipping invalid shape {input_hw}.")
-                continue
-            h_input, w_input = input_hw
-            h_processed = round(h_input / factor) * factor
-            w_processed = round(w_input / factor) * factor
-            t, h, w = 1, h_processed // vc.patch_size, w_processed // vc.patch_size
-            grid_thw = (t, h, w)
-            num_patches = t * h * w
-            patch_input_dim = vc.in_channels * vc.temporal_patch_size * vc.patch_size * vc.patch_size
-
-            dummy_pixel_values = jnp.ones(
-                (num_patches, patch_input_dim),
-                self.vllm_config.model_config.dtype,
-            )
-            dummy_grid_thw = grid_thw
+        patch_input_dim = vc.in_channels * vc.temporal_patch_size * vc.patch_size * vc.patch_size
+        if self.visual.enable_dynamic_image_sizes:
+            spatial_merge_unit = vc.spatial_merge_size**2
+            max_num_batched_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
+            mm_kwargs = self.vllm_config.model_config.multimodal_config.mm_processor_kwargs or {}
+            limit_pixels = float(mm_kwargs.get("max_pixels", float('inf')))
+
+            max_patches = int(
+                min(max_num_batched_tokens * spatial_merge_unit,
+                    limit_pixels / (vc.patch_size**2)))
+
+            num_patches_paddings = [
+                1 << i for i in range(4, (max_patches - 1).bit_length() + 1)
+            ]
+            rotary_dim = vc.hidden_size // vc.num_heads // 2
+            vit_merger_window_size = (vc.window_size //
+                                      vc.spatial_merge_size // vc.patch_size)
+
+            for num_patches in num_patches_paddings:
+                dummy_x_padded = jnp.ones(
+                    (num_patches, patch_input_dim),
+                    dtype=self.vllm_config.model_config.dtype)
+
+                num_tokens = num_patches // spatial_merge_unit
+                dummy_window_index = jnp.arange(num_tokens, dtype=jnp.int32)
+
+                dummy_rotary_pos_emb = jnp.ones(
+                    (num_patches, rotary_dim),
+                    dtype=self.vllm_config.model_config.dtype)
+
+                dummy_cu_seqlens = jnp.array([0, num_patches, num_patches],
+                                             dtype=jnp.int32)
+
+                max_windows = (num_tokens // vit_merger_window_size) + 2
+                patches_per_window = (vit_merger_window_size**
+                                      2) * spatial_merge_unit
+                dummy_cu_window_seqlens = jnp.arange(
+                    max_windows + 1, dtype=jnp.int32) * patches_per_window
+                dummy_cu_window_seqlens = jnp.minimum(dummy_cu_window_seqlens,
+                                                      num_patches)
+
+                run_compilation_fn("vision_encoder_padded",
+                                   self.visual.encode_padded_jit,
+                                   dummy_x_padded,
+                                   dummy_window_index,
+                                   dummy_rotary_pos_emb,
+                                   dummy_cu_seqlens,
+                                   dummy_cu_window_seqlens,
+                                   num_patches=num_patches)
+        else:
+            image_shapes = []
+            if (warmup_config := self.vllm_config.additional_config.get(
+                    "vision_warmup_config")):
+                image_shapes = warmup_config.get("image_shapes")
+
+            factor = vc.patch_size * vc.spatial_merge_size
+            for input_hw in image_shapes:
+                if not isinstance(input_hw, list) or len(input_hw) != 2:
+                    logger.warning(f"Skipping invalid shape {input_hw}.")
+                    continue
+                h_input, w_input = input_hw
+                h_processed = round(h_input / factor) * factor
+                w_processed = round(w_input / factor) * factor
+                t, h, w = 1, h_processed // vc.patch_size, w_processed // vc.patch_size
+                grid_thw = (t, h, w)
+                num_patches = t * h * w
+
+                dummy_pixel_values = jnp.ones(
+                    (num_patches, patch_input_dim),
+                    self.vllm_config.model_config.dtype,
+                )
+                dummy_grid_thw = (grid_thw, )
 
-            run_compilation_fn("single_image_encoder",
-                               self.get_single_image_embedding,
-                               dummy_pixel_values,
-                               dummy_grid_thw,
-                               image_shape=input_hw)
+                run_compilation_fn("vision_encoder",
+                                   self.visual.encode_jit,
+                                   dummy_pixel_values,
+                                   dummy_grid_thw,
+                                   image_shape=input_hw)
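With the dynamic path enabled, warmup precompiles the padded encoder once per power-of-two bucket from 16 up to the derived max_patches. A worked example under assumed config values (max_num_batched_tokens=2048, spatial_merge_size=2, no max_pixels cap):

# Worked example of num_patches_paddings above; the config values are
# assumptions, the formula is the one in the diff.
spatial_merge_unit = 2**2                   # spatial_merge_size ** 2
max_patches = 2048 * spatial_merge_unit     # 8192 (no max_pixels cap)
buckets = [1 << i for i in range(4, (max_patches - 1).bit_length() + 1)]
print(buckets)  # [16, 32, 64, ..., 4096, 8192]: 10 warmup compilations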
tpu_inference/models/jax/utils/quantization/quantization_utils.py

@@ -154,12 +154,9 @@ def qwix_quantize_nnx_model(model: nnx.Module, qwix_config: List[dict],
     logger.info(f"Memory usage before applying quantization of params: "
                 f"hbm={utils.hbm_usage_gb(jax.local_devices())}Gb")
 
-    # TODO (jacobplatin): we should refactor this to pass a dtype (or config) directly
-    kv_cache_jnp_dtype = utils.get_jax_dtype_from_str_dtype(kv_cache_dtype)
-
-    # Handle the case where kv_cache_dtype is "auto"
-    if kv_cache_jnp_dtype is None:
-        assert kv_cache_dtype == "auto", "kv_cache_dtype must be 'auto' if kv_cache_jnp_dtype is None"
+    if kv_cache_dtype != "auto":
+        kv_cache_jnp_dtype = utils.to_jax_dtype(kv_cache_dtype)
+    else:
         kv_cache_jnp_dtype = DEFAULT_KV_CACHE_DTYPE
 
     kv_caches = create_kv_caches(
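The rewrite resolves the KV-cache dtype in one branch: anything other than "auto" goes through utils.to_jax_dtype, and "auto" falls straight to DEFAULT_KV_CACHE_DTYPE instead of round-tripping through a None return plus assert. A minimal sketch of the control flow, with an assumed mapping standing in for to_jax_dtype, whose definition is not shown in this diff:

import jax.numpy as jnp

# Assumed stand-ins for utils.to_jax_dtype and the module-level default.
_STR_TO_JAX_DTYPE = {"bfloat16": jnp.bfloat16, "float32": jnp.float32}
DEFAULT_KV_CACHE_DTYPE = jnp.bfloat16

def resolve_kv_cache_dtype(kv_cache_dtype: str):
    if kv_cache_dtype != "auto":
        return _STR_TO_JAX_DTYPE[kv_cache_dtype]
    return DEFAULT_KV_CACHE_DTYPE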
@@ -169,9 +166,11 @@ def qwix_quantize_nnx_model(model: nnx.Module, qwix_config: List[dict],
         head_size=kv_cache_head_size,
         mesh=mesh,
         layer_names=[f"layer.{i}" for i in range(num_hidden_layers)],
-        cache_dtype=kv_cache_jnp_dtype)
+        cache_dtype=kv_cache_jnp_dtype,
+        use_mla=model.vllm_config.model_config.use_mla,
+    )
 
-    dp_size = mesh.shape.get("data", 1) * mesh.shape.get("attn", 1)
+    dp_size = model.vllm_config.sharding_config.total_dp_size
 
     # NOTE: the inputs don't need to match the actual ones, as long as the consumed weights are the same
     input_ids = jax.random.randint(rng,
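dp_size now comes from a single authoritative field on the sharding config rather than being re-derived from mesh axes. A toy contrast under an assumed mesh shape; only the old formula and the new attribute name come from the diff:

# Assumed mesh axes for illustration.
mesh_shape = {"data": 2, "attn": 2, "model": 4}
old_dp_size = mesh_shape.get("data", 1) * mesh_shape.get("attn", 1)  # 4
# New: dp_size = model.vllm_config.sharding_config.total_dp_size,
# which is expected to match the mesh-derived product.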