tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.12.0.dev20251213__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/kernels/fused_moe_v1_test.py +303 -34
- tests/kernels/mla_v1_test.py +129 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +3 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +3 -1
- tests/lora/test_layers.py +4 -1
- tests/lora/test_lora_perf.py +53 -0
- tests/test_envs.py +110 -12
- tests/test_quantization.py +3 -0
- tests/test_utils.py +1 -2
- tpu_inference/distributed/tpu_connector.py +1 -1
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/ray_distributed_executor.py +5 -1
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/fused_moe/v1/kernel.py +712 -143
- tpu_inference/kernels/mla/v1/kernel.py +98 -120
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +82 -32
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +146 -85
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +2 -1
- tpu_inference/layers/common/attention_interface.py +7 -1
- tpu_inference/layers/common/sharding.py +11 -7
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +232 -64
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +5 -5
- tpu_inference/layers/vllm/fused_moe.py +170 -208
- tpu_inference/layers/vllm/linear_common.py +43 -21
- tpu_inference/layers/vllm/quantization/common.py +11 -6
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +4 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +74 -65
- tpu_inference/layers/vllm/quantization/mxfp4.py +140 -94
- tpu_inference/layers/vllm/quantization/unquantized.py +103 -80
- tpu_inference/models/common/model_loader.py +78 -22
- tpu_inference/models/jax/deepseek_v3.py +185 -64
- tpu_inference/models/jax/gpt_oss.py +3 -3
- tpu_inference/models/jax/llama_eagle3.py +4 -5
- tpu_inference/models/jax/qwen2_5_vl.py +161 -47
- tpu_inference/models/jax/utils/quantization/quantization_utils.py +7 -8
- tpu_inference/models/jax/utils/weight_utils.py +203 -155
- tpu_inference/models/vllm/vllm_model_wrapper.py +11 -5
- tpu_inference/platforms/tpu_platform.py +29 -48
- tpu_inference/runner/compilation_manager.py +112 -46
- tpu_inference/runner/kv_cache.py +40 -20
- tpu_inference/runner/kv_cache_manager.py +40 -31
- tpu_inference/runner/persistent_batch_manager.py +40 -2
- tpu_inference/runner/structured_decoding_manager.py +2 -3
- tpu_inference/runner/tpu_runner.py +94 -51
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/jax/eagle3.py +71 -22
- tpu_inference/utils.py +41 -14
- tpu_inference/worker/tpu_worker.py +43 -45
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/METADATA +8 -9
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/RECORD +59 -58
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.12.0.dev20251213.dist-info}/top_level.txt +0 -0
|
@@ -486,6 +486,11 @@ class Qwen2_5_VisionTransformer(nnx.Module):
|
|
|
486
486
|
dtype=dtype,
|
|
487
487
|
rngs=rngs)
|
|
488
488
|
|
|
489
|
+
additional_config = getattr(vllm_config, "additional_config",
|
|
490
|
+
None) or {}
|
|
491
|
+
self.enable_dynamic_image_sizes = additional_config.get(
|
|
492
|
+
"enable_dynamic_image_sizes", False)
|
|
493
|
+
|
|
489
494
|
def rotary_pos_emb_thw(self, t, h, w):
|
|
490
495
|
hpos_ids, wpos_ids = jnp.indices((h, w))
|
|
491
496
|
hpos_ids = hpos_ids.reshape(
|
|
@@ -579,21 +584,7 @@ class Qwen2_5_VisionTransformer(nnx.Module):
|
|
|
579
584
|
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
|
580
585
|
return max_seqlen, seqlens
|
|
581
586
|
|
|
582
|
-
def
|
|
583
|
-
int]]) -> jax.Array:
|
|
584
|
-
# x: pixel_values: jax.Array
|
|
585
|
-
# """Shape:
|
|
586
|
-
# `(num_patches, num_channels * patch_size * patch_size)`
|
|
587
|
-
# """
|
|
588
|
-
|
|
589
|
-
# grid_thw: image_grid_thw: jax.Array
|
|
590
|
-
# """Shape: `(num_images, 3)`
|
|
591
|
-
# This should be in `(grid_t, grid_h, grid_w)` format.
|
|
592
|
-
# """
|
|
593
|
-
hidden_states = self.patch_embed(x)
|
|
594
|
-
|
|
595
|
-
# num of patches
|
|
596
|
-
seq_len = x.shape[0]
|
|
587
|
+
def compute_aux_arrays(self, grid_thw: tuple[tuple[int, int, int]]):
|
|
597
588
|
# num of images/videoes
|
|
598
589
|
num_grids = len(grid_thw)
|
|
599
590
|
|
|
@@ -638,6 +629,42 @@ class Qwen2_5_VisionTransformer(nnx.Module):
|
|
|
638
629
|
cu_seqlens = jnp.pad(cu_seqlens, ((1, 0), ),
|
|
639
630
|
mode='constant',
|
|
640
631
|
constant_values=0)
|
|
632
|
+
return window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens
|
|
633
|
+
|
|
634
|
+
def pad_inputs(self, x, window_index, rotary_pos_emb, cu_seqlens,
|
|
635
|
+
cu_window_seqlens):
|
|
636
|
+
# padding
|
|
637
|
+
num_patches = int(rotary_pos_emb.shape[0])
|
|
638
|
+
bucket_num_patches = 1 << (num_patches - 1).bit_length()
|
|
639
|
+
num_tokens = window_index.shape[0]
|
|
640
|
+
bucket_num_tokens = bucket_num_patches // self.spatial_merge_unit
|
|
641
|
+
vit_merger_window_size = (self.window_size //
|
|
642
|
+
self.spatial_merge_size // self.patch_size)
|
|
643
|
+
max_windows = (bucket_num_tokens // vit_merger_window_size) + 2
|
|
644
|
+
|
|
645
|
+
rotary_pos_emb = jnp.pad(rotary_pos_emb,
|
|
646
|
+
((0, bucket_num_patches - num_patches),
|
|
647
|
+
(0, 0)))
|
|
648
|
+
window_index = jnp.concatenate([
|
|
649
|
+
window_index,
|
|
650
|
+
jnp.arange(num_tokens, bucket_num_tokens, dtype=jnp.int32)
|
|
651
|
+
])
|
|
652
|
+
cu_window_seqlens = jnp.append(cu_window_seqlens, bucket_num_patches)
|
|
653
|
+
pad_w = max(0, max_windows + 1 - cu_window_seqlens.shape[0])
|
|
654
|
+
cu_window_seqlens = jnp.pad(cu_window_seqlens, (0, pad_w), mode='edge')
|
|
655
|
+
cu_seqlens = jnp.append(cu_seqlens, bucket_num_patches)
|
|
656
|
+
|
|
657
|
+
x_padded = jnp.pad(x, ((0, bucket_num_patches - x.shape[0]), (0, 0)))
|
|
658
|
+
|
|
659
|
+
return x_padded, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens, num_tokens
|
|
660
|
+
|
|
661
|
+
def compute_hidden_states(self, x: jax.Array, window_index: jax.Array,
|
|
662
|
+
rotary_pos_emb: jax.Array, cu_seqlens: jax.Array,
|
|
663
|
+
cu_window_seqlens: jax.Array) -> jax.Array:
|
|
664
|
+
hidden_states = self.patch_embed(x)
|
|
665
|
+
|
|
666
|
+
# num of patches
|
|
667
|
+
seq_len = x.shape[0]
|
|
641
668
|
|
|
642
669
|
hidden_states = hidden_states.reshape(
|
|
643
670
|
seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
|
|
@@ -664,6 +691,48 @@ class Qwen2_5_VisionTransformer(nnx.Module):
|
|
|
664
691
|
hidden_states = hidden_states[reverse_indices, :]
|
|
665
692
|
return hidden_states
|
|
666
693
|
|
|
694
|
+
@jax.jit
|
|
695
|
+
def encode_padded_jit(self, x_padded, window_index, rotary_pos_emb,
|
|
696
|
+
cu_seqlens, cu_window_seqlens):
|
|
697
|
+
return self.compute_hidden_states(x_padded, window_index,
|
|
698
|
+
rotary_pos_emb, cu_seqlens,
|
|
699
|
+
cu_window_seqlens)
|
|
700
|
+
|
|
701
|
+
@partial(
|
|
702
|
+
jax.jit,
|
|
703
|
+
static_argnames=("grid_thw", ),
|
|
704
|
+
)
|
|
705
|
+
def encode_jit(self, x, grid_thw):
|
|
706
|
+
window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens = self.compute_aux_arrays(
|
|
707
|
+
grid_thw)
|
|
708
|
+
return self.compute_hidden_states(x, window_index, rotary_pos_emb,
|
|
709
|
+
cu_seqlens, cu_window_seqlens)
|
|
710
|
+
|
|
711
|
+
def __call__(self, x: jax.Array, grid_thw: tuple[tuple[int, int,
|
|
712
|
+
int]]) -> jax.Array:
|
|
713
|
+
# x: pixel_values: jax.Array
|
|
714
|
+
# """Shape:
|
|
715
|
+
# `(num_patches, num_channels * patch_size * patch_size)`
|
|
716
|
+
# """
|
|
717
|
+
|
|
718
|
+
# grid_thw: image_grid_thw: jax.Array
|
|
719
|
+
# """Shape: `(num_images, 3)`
|
|
720
|
+
# This should be in `(grid_t, grid_h, grid_w)` format.
|
|
721
|
+
# """
|
|
722
|
+
if self.enable_dynamic_image_sizes:
|
|
723
|
+
window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens = self.compute_aux_arrays(
|
|
724
|
+
grid_thw)
|
|
725
|
+
x_padded, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens, num_tokens = self.pad_inputs(
|
|
726
|
+
x, window_index, rotary_pos_emb, cu_seqlens, cu_window_seqlens)
|
|
727
|
+
|
|
728
|
+
hidden_states = self.encode_padded_jit(x_padded, window_index,
|
|
729
|
+
rotary_pos_emb, cu_seqlens,
|
|
730
|
+
cu_window_seqlens)
|
|
731
|
+
return hidden_states[:num_tokens]
|
|
732
|
+
|
|
733
|
+
else:
|
|
734
|
+
return self.encode_jit(x, grid_thw)
|
|
735
|
+
|
|
667
736
|
|
|
668
737
|
class Qwen2_5_VLForConditionalGeneration(nnx.Module):
|
|
669
738
|
|
|
@@ -888,10 +957,6 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
|
|
|
888
957
|
# "video"] = self._parse_and_validate_video_input(**kwargs)
|
|
889
958
|
return mm_input_by_modality
|
|
890
959
|
|
|
891
|
-
@partial(
|
|
892
|
-
jax.jit,
|
|
893
|
-
static_argnames=("image_grid_thw", ),
|
|
894
|
-
)
|
|
895
960
|
def get_single_image_embedding(self, image_pixel_values, image_grid_thw):
|
|
896
961
|
return self.visual(image_pixel_values, (image_grid_thw, ))
|
|
897
962
|
|
|
@@ -1072,33 +1137,82 @@ class Qwen2_5_VLForConditionalGeneration(nnx.Module):
|
|
|
1072
1137
|
self,
|
|
1073
1138
|
run_compilation_fn: Callable,
|
|
1074
1139
|
) -> None:
|
|
1075
|
-
image_shapes = []
|
|
1076
|
-
if (warmup_config := self.vllm_config.additional_config.get(
|
|
1077
|
-
"vision_warmup_config")):
|
|
1078
|
-
image_shapes = warmup_config.get("image_shapes")
|
|
1079
|
-
|
|
1080
1140
|
vc = self.vllm_config.model_config.hf_config.vision_config
|
|
1081
|
-
|
|
1082
|
-
|
|
1083
|
-
|
|
1084
|
-
|
|
1085
|
-
|
|
1086
|
-
|
|
1087
|
-
|
|
1088
|
-
|
|
1089
|
-
|
|
1090
|
-
|
|
1091
|
-
|
|
1092
|
-
|
|
1093
|
-
|
|
1094
|
-
|
|
1095
|
-
|
|
1096
|
-
|
|
1097
|
-
|
|
1098
|
-
|
|
1141
|
+
patch_input_dim = vc.in_channels * vc.temporal_patch_size * vc.patch_size * vc.patch_size
|
|
1142
|
+
if self.visual.enable_dynamic_image_sizes:
|
|
1143
|
+
spatial_merge_unit = vc.spatial_merge_size**2
|
|
1144
|
+
max_num_batched_tokens = self.vllm_config.scheduler_config.max_num_batched_tokens
|
|
1145
|
+
mm_kwargs = self.vllm_config.model_config.multimodal_config.mm_processor_kwargs or {}
|
|
1146
|
+
limit_pixels = float(mm_kwargs.get("max_pixels", float('inf')))
|
|
1147
|
+
|
|
1148
|
+
max_patches = int(
|
|
1149
|
+
min(max_num_batched_tokens * spatial_merge_unit,
|
|
1150
|
+
limit_pixels / (vc.patch_size**2)))
|
|
1151
|
+
|
|
1152
|
+
num_patches_paddings = [
|
|
1153
|
+
1 << i for i in range(4, (max_patches - 1).bit_length() + 1)
|
|
1154
|
+
]
|
|
1155
|
+
rotary_dim = vc.hidden_size // vc.num_heads // 2
|
|
1156
|
+
vit_merger_window_size = (vc.window_size //
|
|
1157
|
+
vc.spatial_merge_size // vc.patch_size)
|
|
1158
|
+
|
|
1159
|
+
for num_patches in num_patches_paddings:
|
|
1160
|
+
dummy_x_padded = jnp.ones(
|
|
1161
|
+
(num_patches, patch_input_dim),
|
|
1162
|
+
dtype=self.vllm_config.model_config.dtype)
|
|
1163
|
+
|
|
1164
|
+
num_tokens = num_patches // spatial_merge_unit
|
|
1165
|
+
dummy_window_index = jnp.arange(num_tokens, dtype=jnp.int32)
|
|
1166
|
+
|
|
1167
|
+
dummy_rotary_pos_emb = jnp.ones(
|
|
1168
|
+
(num_patches, rotary_dim),
|
|
1169
|
+
dtype=self.vllm_config.model_config.dtype)
|
|
1170
|
+
|
|
1171
|
+
dummy_cu_seqlens = jnp.array([0, num_patches, num_patches],
|
|
1172
|
+
dtype=jnp.int32)
|
|
1173
|
+
|
|
1174
|
+
max_windows = (num_tokens // vit_merger_window_size) + 2
|
|
1175
|
+
patches_per_window = (vit_merger_window_size**
|
|
1176
|
+
2) * spatial_merge_unit
|
|
1177
|
+
dummy_cu_window_seqlens = jnp.arange(
|
|
1178
|
+
max_windows + 1, dtype=jnp.int32) * patches_per_window
|
|
1179
|
+
dummy_cu_window_seqlens = jnp.minimum(dummy_cu_window_seqlens,
|
|
1180
|
+
num_patches)
|
|
1181
|
+
|
|
1182
|
+
run_compilation_fn("vision_encoder_padded",
|
|
1183
|
+
self.visual.encode_padded_jit,
|
|
1184
|
+
dummy_x_padded,
|
|
1185
|
+
dummy_window_index,
|
|
1186
|
+
dummy_rotary_pos_emb,
|
|
1187
|
+
dummy_cu_seqlens,
|
|
1188
|
+
dummy_cu_window_seqlens,
|
|
1189
|
+
num_patches=num_patches)
|
|
1190
|
+
else:
|
|
1191
|
+
image_shapes = []
|
|
1192
|
+
if (warmup_config := self.vllm_config.additional_config.get(
|
|
1193
|
+
"vision_warmup_config")):
|
|
1194
|
+
image_shapes = warmup_config.get("image_shapes")
|
|
1195
|
+
|
|
1196
|
+
factor = vc.patch_size * vc.spatial_merge_size
|
|
1197
|
+
for input_hw in image_shapes:
|
|
1198
|
+
if not isinstance(input_hw, list) or len(input_hw) != 2:
|
|
1199
|
+
logger.warning(f"Skipping invalid shape {input_hw}.")
|
|
1200
|
+
continue
|
|
1201
|
+
h_input, w_input = input_hw
|
|
1202
|
+
h_processed = round(h_input / factor) * factor
|
|
1203
|
+
w_processed = round(w_input / factor) * factor
|
|
1204
|
+
t, h, w = 1, h_processed // vc.patch_size, w_processed // vc.patch_size
|
|
1205
|
+
grid_thw = (t, h, w)
|
|
1206
|
+
num_patches = t * h * w
|
|
1207
|
+
|
|
1208
|
+
dummy_pixel_values = jnp.ones(
|
|
1209
|
+
(num_patches, patch_input_dim),
|
|
1210
|
+
self.vllm_config.model_config.dtype,
|
|
1211
|
+
)
|
|
1212
|
+
dummy_grid_thw = (grid_thw, )
|
|
1099
1213
|
|
|
1100
|
-
|
|
1101
|
-
|
|
1102
|
-
|
|
1103
|
-
|
|
1104
|
-
|
|
1214
|
+
run_compilation_fn("vision_encoder",
|
|
1215
|
+
self.visual.encode_jit,
|
|
1216
|
+
dummy_pixel_values,
|
|
1217
|
+
dummy_grid_thw,
|
|
1218
|
+
image_shape=input_hw)
|
|
@@ -154,12 +154,9 @@ def qwix_quantize_nnx_model(model: nnx.Module, qwix_config: List[dict],
|
|
|
154
154
|
logger.info(f"Memory usage before applying quantization of params: "
|
|
155
155
|
f"hbm={utils.hbm_usage_gb(jax.local_devices())}Gb")
|
|
156
156
|
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
# Handle the case where kv_cache_dtype is "auto"
|
|
161
|
-
if kv_cache_jnp_dtype is None:
|
|
162
|
-
assert kv_cache_dtype == "auto", "kv_cache_dtype must be 'auto' if kv_cache_jnp_dtype is None"
|
|
157
|
+
if kv_cache_dtype != "auto":
|
|
158
|
+
kv_cache_jnp_dtype = utils.to_jax_dtype(kv_cache_dtype)
|
|
159
|
+
else:
|
|
163
160
|
kv_cache_jnp_dtype = DEFAULT_KV_CACHE_DTYPE
|
|
164
161
|
|
|
165
162
|
kv_caches = create_kv_caches(
|
|
@@ -169,9 +166,11 @@ def qwix_quantize_nnx_model(model: nnx.Module, qwix_config: List[dict],
|
|
|
169
166
|
head_size=kv_cache_head_size,
|
|
170
167
|
mesh=mesh,
|
|
171
168
|
layer_names=[f"layer.{i}" for i in range(num_hidden_layers)],
|
|
172
|
-
cache_dtype=kv_cache_jnp_dtype
|
|
169
|
+
cache_dtype=kv_cache_jnp_dtype,
|
|
170
|
+
use_mla=model.vllm_config.model_config.use_mla,
|
|
171
|
+
)
|
|
173
172
|
|
|
174
|
-
dp_size =
|
|
173
|
+
dp_size = model.vllm_config.sharding_config.total_dp_size
|
|
175
174
|
|
|
176
175
|
# NOTE: the inputs don't need to match the actual ones, as long as the consumed weights are the same
|
|
177
176
|
input_ids = jax.random.randint(rng,
|