tpu-inference 0.11.1.dev202511270815__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +22 -1
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +167 -97
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +31 -9
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +280 -210
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +77 -36
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +91 -31
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -71
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +158 -63
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +53 -30
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +105 -57
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +65 -19
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +65 -52
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
- tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511270815.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511270815.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,605 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from functools import partial
|
|
16
|
+
from unittest.mock import MagicMock, patch
|
|
17
|
+
|
|
18
|
+
import jax
|
|
19
|
+
import jax.numpy as jnp
|
|
20
|
+
import numpy as np
|
|
21
|
+
import pytest
|
|
22
|
+
from flax import nnx
|
|
23
|
+
from flax.typing import PRNGKey
|
|
24
|
+
from jax.sharding import Mesh
|
|
25
|
+
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import \
|
|
26
|
+
Qwen2_5_VLConfig
|
|
27
|
+
from vllm.config import (CacheConfig, DeviceConfig, MultiModalConfig,
|
|
28
|
+
ParallelConfig, SchedulerConfig)
|
|
29
|
+
|
|
30
|
+
# The @patch decorators below target this module by its dotted path, so the module object itself does not need to be imported here.
|
|
31
|
+
# Names under test from the Qwen2.5-VL JAX model module.
|
|
32
|
+
from tpu_inference.models.jax.qwen2_5_vl import (
|
|
33
|
+
AttentionMetadata, Qwen2_5_VisionAttention, Qwen2_5_VisionBlock,
|
|
34
|
+
Qwen2_5_VisionMLP, Qwen2_5_VisionPatchEmbed, Qwen2_5_VisionPatchMerger,
|
|
35
|
+
Qwen2_5_VisionRotaryEmbedding, Qwen2_5_VisionTransformer,
|
|
36
|
+
Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLImagePixelInputs, SegmentIds,
|
|
37
|
+
apply_rotary_pos_emb_vision, generate_window_segment_ids)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# --- Configuration Mocking ---
|
|
41
|
+
class MockModelConfig:
    """Lightweight stand-in for vLLM's model config used by these tests.

    Holds the HF config and compute dtype, builds a multimodal config from
    the HF image token id, and exposes a few tokenizer/loader attributes.
    """

    def __init__(self, hf_config, dtype):
        self.hf_config = hf_config
        self.dtype = dtype
        self.multimodal_config = MultiModalConfig(
            image_input_type="pixel",
            image_token_id=hf_config.image_token_id,
            image_input_shape=None,
        )
        self.model = "mock_qwen2_5_vl"
        # Extra attributes that code under test may look up on the config.
        self.tokenizer = "mock_tokenizer"
        self.tokenizer_mode = "auto"
        self.trust_remote_code = True
        self.seed = 0

    def is_multimodal_model(self):
        """Always report that the mocked model is multimodal."""
        return True

    def get_hidden_size(self):
        """Return the text model hidden size taken from the HF config."""
        cfg = self.hf_config
        return cfg.hidden_size

    def get_head_size(self):
        """Return the per-head width: hidden size floor-divided by head count."""
        cfg = self.hf_config
        return cfg.hidden_size // cfg.num_attention_heads
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class MockVllmConfig:
    """A mock VllmConfig sufficient for testing the Qwen2.5 VL model.

    Wraps a tiny, real ``Qwen2_5_VLConfig``: 16-wide hidden states, 2 text
    layers and a 2-block vision tower. The remaining sub-configs are
    ``MagicMock`` stand-ins; the spec'd ones only allow attributes that
    exist on the corresponding vLLM class.
    """

    def __init__(self, tie_word_embeddings: bool = False):
        vision_config = dict(
            hidden_size=16,
            intermediate_size=32,
            patch_size=14,
            image_size=28,
            temporal_patch_size=2,
            in_channels=3,
            window_size=28,
            spatial_merge_size=2,
            fullatt_block_indexes=[0],
            out_hidden_size=24,
            depth=2,
            hidden_act="gelu",
            num_heads=2,
        )
        text_kwargs = dict(
            hidden_size=16,
            num_hidden_layers=2,
            num_attention_heads=2,
            num_key_value_heads=2,
            intermediate_size=32,
            rms_norm_eps=1e-6,
            image_token_id=200000,
            video_token_id=200001,
            tie_word_embeddings=tie_word_embeddings,
            vocab_size=32000,
            rope_theta=1000000.0,
        )
        hf_config = Qwen2_5_VLConfig(vision_config=vision_config,
                                     **text_kwargs)

        self.model_config = MockModelConfig(hf_config, jnp.bfloat16)
        # Spec'd mocks: attribute access is restricted to the real class API.
        self.cache_config = MagicMock(spec=CacheConfig)
        self.parallelism_config = MagicMock(spec=ParallelConfig)
        self.scheduler_config = MagicMock(spec=SchedulerConfig)
        self.device_config = MagicMock(spec=DeviceConfig)
        self.load_config = MagicMock()
        self.extra_configs = {}
        self.additional_config = {}
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@pytest.fixture(scope="module")
def mesh():
    """Creates a mesh with all required axes for testing.

    All local devices are laid out along the 'data' axis. The 'attn_dp' and
    'model' axes have size 1. Tests using this fixture are skipped when no
    local JAX devices are visible.
    """
    # Guard on the same device list the mesh is built from. Previously the
    # check used jax.devices() (all processes) while the mesh used
    # jax.local_devices(), so the guard did not test what was actually used.
    local = jax.local_devices()
    if not local:
        pytest.skip("No JAX devices available for mesh creation.")
    devices = np.array(local)
    return Mesh(devices.reshape((len(devices), 1, 1)),
                axis_names=('data', 'attn_dp', 'model'))
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@pytest.fixture
def rng() -> PRNGKey:
    """Return a JAX PRNG key seeded with 42, rebuilt for every test."""
    seed = 42
    return jax.random.PRNGKey(seed)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
@pytest.fixture
def mock_vllm_config() -> MockVllmConfig:
    """Return a fresh MockVllmConfig (default, untied embeddings) per test."""
    cfg = MockVllmConfig()
    return cfg
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
@pytest.fixture
def rngs(rng: PRNGKey) -> nnx.Rngs:
    """Wrap the ``rng`` key as the ``params`` stream of an ``nnx.Rngs``."""
    streams = {"params": rng}
    return nnx.Rngs(**streams)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
# --- Test Classes ---
|
|
137
|
+
class TestUtils:
    """Tests for the module-level vision helper functions."""

    def test_apply_rotary_pos_emb_vision(self, rng: PRNGKey):
        """Applying the rotary embedding preserves the (B, T, N, H) shape."""
        B, T, N, H = 1, 10, 2, 8
        # Draw the input and the rotary table from independent subkeys;
        # JAX keys must not be reused for unrelated random draws.
        x_key, emb_key = jax.random.split(rng)
        x = jax.random.normal(x_key, (B, T, N, H))
        rotary_pos_emb = jax.random.normal(emb_key, (T, H // 2))
        x_rotated = apply_rotary_pos_emb_vision(x, rotary_pos_emb)
        assert x_rotated.shape == (B, T, N, H)

    def test_generate_window_segment_ids(self):
        """Windows [0, 5) and [5, 10) get ids 1 and 2; padding gets id 0."""
        cu_seqlens = jnp.array([0, 5, 10])
        seq_len = 10
        padded_seq_len = 16
        segment_ids = generate_window_segment_ids(cu_seqlens, seq_len,
                                                  padded_seq_len)
        assert isinstance(segment_ids, SegmentIds)
        assert segment_ids.q.shape == (1, padded_seq_len)
        assert segment_ids.kv.shape == (1, padded_seq_len)
        expected_q = np.array(
            [[1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0]])
        np.testing.assert_array_equal(segment_ids.q, expected_q)
        # Query and key/value ids are expected to be identical.
        np.testing.assert_array_equal(segment_ids.kv, expected_q)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class TestQwen2_5_VisionMLP:
    """Tests for the vision tower's MLP sub-module."""

    def test_forward(self, mock_vllm_config: MockVllmConfig, rngs: nnx.Rngs):
        """The MLP maps (tokens, hidden) to the same shape and keeps dtype."""
        model_cfg = mock_vllm_config.model_config
        vision_cfg = model_cfg.hf_config.vision_config
        dtype = model_cfg.dtype
        num_tokens = 5

        mlp = Qwen2_5_VisionMLP(vision_cfg, dtype, rngs)
        out = mlp(jnp.ones((num_tokens, vision_cfg.hidden_size), dtype=dtype))

        assert out.shape == (num_tokens, vision_cfg.hidden_size)
        assert out.dtype == dtype
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
class TestQwen2_5_VisionAttention:
    """Tests for Qwen2_5_VisionAttention.

    The forward tests patch the ``sharded_flash_attention`` factory and then
    replace the module's ``flash_attention`` with a MagicMock, so the real
    kernel is never invoked.
    """

    # Length the segment ids handed to the kernel are expected to have for
    # T=10 (the padded sequence length these assertions rely on).
    _PADDED_SEQ_LEN = 128

    @staticmethod
    def _make_attention(mock_vllm_config, rngs, mesh):
        # Build the module under test from the shared mock config.
        model_config = mock_vllm_config.model_config
        return Qwen2_5_VisionAttention(model_config.hf_config,
                                       model_config.dtype, rngs, mesh)

    def _install_mock_kernel(self, attn_module, batch):
        # Replace the kernel with a MagicMock that returns a padded output.
        mock_attn_fn = MagicMock(return_value=jnp.ones(
            (batch, attn_module.num_heads, self._PADDED_SEQ_LEN,
             attn_module.head_dim)))
        attn_module.flash_attention = mock_attn_fn
        return mock_attn_fn

    @staticmethod
    def _random_inputs(rng, T, B, D, head_dim):
        # Draw x and the rotary table from independent subkeys (JAX keys
        # must not be reused for unrelated draws).
        x_key, emb_key = jax.random.split(rng)
        x = jax.random.normal(x_key, (T, B, D))
        rotary_pos_emb = jax.random.normal(emb_key, (T, head_dim // 2))
        return x, rotary_pos_emb

    @patch('tpu_inference.models.jax.qwen2_5_vl.sharded_flash_attention')
    def test_forward_fullattn(self, mock_flash_attention: MagicMock,
                              mock_vllm_config: MockVllmConfig, rngs: nnx.Rngs,
                              mesh: Mesh, rng: PRNGKey):
        """Full attention keeps the (T, B, D) shape and calls the kernel once."""
        attn_module = self._make_attention(mock_vllm_config, rngs, mesh)
        B, T, D = 1, 10, attn_module.hidden_size
        # sharded_flash_attention is a factory, so we mock the returned function
        mock_attn_fn = self._install_mock_kernel(attn_module, B)
        x, rotary_pos_emb = self._random_inputs(rng, T, B, D,
                                                attn_module.head_dim)
        cu_seqlens = jnp.array([0, 5])

        y_full = attn_module(x,
                             rotary_pos_emb,
                             cu_window_seqlens=cu_seqlens,
                             use_fullattn=True)
        assert y_full.shape == (T, B, D)
        mock_attn_fn.assert_called_once()
        # Positional arg 3 is the SegmentIds passed to the kernel.
        assert mock_attn_fn.call_args[0][3].q.shape == (1,
                                                         self._PADDED_SEQ_LEN)

    @patch('tpu_inference.models.jax.qwen2_5_vl.sharded_flash_attention')
    def test_forward_windowed(self, mock_flash_attention: MagicMock,
                              mock_vllm_config: MockVllmConfig, rngs: nnx.Rngs,
                              mesh: Mesh, rng: PRNGKey):
        """Windowed attention keeps (T, B, D) and calls the kernel once."""
        attn_module = self._make_attention(mock_vllm_config, rngs, mesh)
        B, T, D = 1, 10, attn_module.hidden_size
        mock_attn_fn = self._install_mock_kernel(attn_module, B)
        x, rotary_pos_emb = self._random_inputs(rng, T, B, D,
                                                attn_module.head_dim)
        cu_window_seqlens = jnp.array([0, 5, 10])

        y_window = attn_module(x,
                               rotary_pos_emb,
                               cu_window_seqlens=cu_window_seqlens,
                               use_fullattn=False)
        assert y_window.shape == (T, B, D)
        mock_attn_fn.assert_called_once()
        assert mock_attn_fn.call_args[0][3].q.shape == (1,
                                                         self._PADDED_SEQ_LEN)

    def test_batch_fail(self, mock_vllm_config: MockVllmConfig, rngs: nnx.Rngs,
                        mesh: Mesh, rng: PRNGKey):
        """A batch dimension other than 1 is rejected with AssertionError."""
        attn_module = self._make_attention(mock_vllm_config, rngs, mesh)
        T, B, D = 10, 2, attn_module.hidden_size
        x, rotary_pos_emb = self._random_inputs(rng, T, B, D,
                                                attn_module.head_dim)
        with pytest.raises(
                AssertionError,
                match="Vision attention currently only supports batch size 1"):
            attn_module(x, rotary_pos_emb, use_fullattn=True)
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
class TestQwen2_5_VisionBlock:
    """Tests for one vision transformer block with attention and MLP mocked."""

    @patch('tpu_inference.models.jax.qwen2_5_vl.Qwen2_5_VisionMLP',
           autospec=True)
    @patch('tpu_inference.models.jax.qwen2_5_vl.Qwen2_5_VisionAttention',
           autospec=True)
    def test_forward(self, MockAttention: MagicMock, MockMLP: MagicMock,
                     mock_vllm_config: MockVllmConfig, rngs: nnx.Rngs,
                     mesh: Mesh, rng: PRNGKey):
        """The block keeps the (T, B, D) shape and runs attention and MLP once."""
        # Stacked @patch decorators inject mocks bottom-up, so the attention
        # mock arrives before the MLP mock.
        config = mock_vllm_config.model_config.hf_config
        dtype = mock_vllm_config.model_config.dtype
        D = config.vision_config.hidden_size
        T, B = 10, 1

        mock_attn_instance = MockAttention.return_value
        mock_attn_instance.return_value = jnp.zeros((T, B, D), dtype=dtype)
        mock_mlp_instance = MockMLP.return_value
        mock_mlp_instance.return_value = jnp.zeros((T, B, D), dtype=dtype)

        block = Qwen2_5_VisionBlock(config, dtype, rngs, mesh)
        # Independent subkeys for the input and the rotary table; JAX keys
        # must not be reused for unrelated draws.
        x_key, emb_key = jax.random.split(rng)
        x = jax.random.normal(x_key, (T, B, D))
        head_dim = (config.vision_config.hidden_size //
                    config.vision_config.num_heads)
        rotary_pos_emb = jax.random.normal(emb_key, (T, head_dim // 2))

        y = block(x, rotary_pos_emb, use_fullattn=True)
        assert y.shape == (T, B, D)
        mock_attn_instance.assert_called_once()
        mock_mlp_instance.assert_called_once()
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
class TestQwen2_5_VisionPatchEmbed:
    """Tests for the patch-embedding projection."""

    def test_forward(self, mock_vllm_config: MockVllmConfig, rngs: nnx.Rngs,
                     rng: PRNGKey):
        """Flattened C*T*P*P patches are projected to ``hidden_size``."""
        vc = mock_vllm_config.model_config.hf_config.vision_config
        embed_kwargs = dict(
            patch_size=vc.patch_size,
            temporal_patch_size=vc.temporal_patch_size,
            in_channels=vc.in_channels,
            hidden_size=vc.hidden_size,
            dtype=mock_vllm_config.model_config.dtype,
        )
        patch_embed = Qwen2_5_VisionPatchEmbed(rngs, **embed_kwargs)

        num_patches = 4
        flat_patch = (vc.in_channels * vc.temporal_patch_size *
                      vc.patch_size * vc.patch_size)
        patches = jax.random.normal(rng, (num_patches, flat_patch))

        assert patch_embed(patches).shape == (num_patches, vc.hidden_size)
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
class TestQwen2_5_VisionPatchMerger:
    """Tests for the spatial patch merger."""

    def test_forward(self, mock_vllm_config: MockVllmConfig, rngs: nnx.Rngs,
                     rng: PRNGKey):
        """Groups of ``spatial_merge_size**2`` tokens map to ``out_hidden_size``."""
        vc = mock_vllm_config.model_config.hf_config.vision_config
        rms_norm = partial(nnx.RMSNorm, epsilon=1e-6)
        merger = Qwen2_5_VisionPatchMerger(
            d_model=vc.out_hidden_size,
            context_dim=vc.hidden_size,
            norm_layer=rms_norm,
            spatial_merge_size=vc.spatial_merge_size,
            dtype=mock_vllm_config.model_config.dtype,
            rngs=rngs)

        num_groups = 5
        group_size = vc.spatial_merge_size**2
        tokens = jax.random.normal(rng,
                                   (num_groups, group_size, vc.hidden_size))

        merged = merger(tokens)
        assert merged.shape == (num_groups, vc.out_hidden_size)
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
class TestQwen2_5_VisionRotaryEmbedding:
    """Tests for the vision rotary position embedding."""

    def test_forward(self):
        """An embedding for ``seqlen`` positions has ``dim // 2`` columns."""
        dim, seqlen = 16, 10
        table = Qwen2_5_VisionRotaryEmbedding(dim=dim)(seqlen)
        assert table.shape == (seqlen, dim // 2)
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
class TestQwen2_5_VisionTransformer:
    """Tests for the vision tower: RoPE/window helpers and the forward pass."""

    @pytest.fixture
    def vision_transformer(self, mock_vllm_config: MockVllmConfig,
                           rngs: nnx.Rngs, mesh: Mesh):
        """Vision transformer built from the default mock config."""
        return Qwen2_5_VisionTransformer(mock_vllm_config, rngs, mesh)

    def test_rotary_pos_emb_thw(self,
                                vision_transformer: Qwen2_5_VisionTransformer):
        """RoPE for a (t, h, w) grid is grouped by merged spatial token."""
        t, h, w = 2, 4, 4
        emb = vision_transformer.rotary_pos_emb_thw(t, h, w)

        vc = vision_transformer.config
        merge = vc.spatial_merge_size
        half_head = (vc.hidden_size // vc.num_heads) // 2
        merged_tokens = t * (h // merge) * (w // merge)
        assert emb.shape == (merged_tokens, merge * merge, half_head)

    def test_get_window_index_thw(
            self, vision_transformer: Qwen2_5_VisionTransformer):
        """The window index has one non-negative entry per merged token."""
        grid_t, grid_h, grid_w = 1, 8, 8
        index_new, _ = vision_transformer.get_window_index_thw(
            grid_t, grid_h, grid_w)

        merge = vision_transformer.config.spatial_merge_size
        merged_tokens = grid_t * (grid_h // merge) * (grid_w // merge)
        assert index_new.shape == (merged_tokens, )
        assert jnp.all(index_new >= 0)

    def test_get_rope_by_thw(self,
                             vision_transformer: Qwen2_5_VisionTransformer):
        """get_rope_by_thw returns a 4-tuple led by a flat RoPE table."""
        t, h, w = 1, 8, 8
        res = vision_transformer.get_rope_by_thw(t, h, w)
        assert isinstance(res, tuple)
        assert len(res) == 4
        rotary_pos_emb_thw, window_index_thw, _, _ = res

        vc = vision_transformer.config
        merge = vc.spatial_merge_size
        # Each position carries a head_dim // 2 rotary vector.
        half_head = (vc.hidden_size // vc.num_heads) // 2
        flat_len = window_index_thw.shape[0] * merge * merge
        assert rotary_pos_emb_thw.shape == (flat_len, half_head)

    @pytest.mark.parametrize("enable_dynamic_image_sizes", [False, True])
    def test_call(self, mock_vllm_config: MockVllmConfig, rngs: nnx.Rngs,
                  mesh: Mesh, rng: PRNGKey, enable_dynamic_image_sizes: bool):
        """Forward one 2x84x28 pixel input with and without dynamic sizes."""
        mock_vllm_config.additional_config = {
            "enable_dynamic_image_sizes": enable_dynamic_image_sizes
        }
        vision_transformer = Qwen2_5_VisionTransformer(mock_vllm_config, rngs,
                                                       mesh)

        # Stub the attention kernel in every block to avoid sharding errors
        # in the test environment; the stub echoes ones shaped like q.
        def fake_kernel(q, k, v, seg):
            return jnp.ones_like(q)

        for block in vision_transformer.blocks:
            block.attn.flash_attention = MagicMock(side_effect=fake_kernel)

        vc = vision_transformer.config
        t_pix, h_pix, w_pix = 2, 84, 28

        # grid_thw is expressed in patch-grid units, not pixels.
        t_grid = t_pix // vc.temporal_patch_size
        h_grid = h_pix // vc.patch_size
        w_grid = w_pix // vc.patch_size
        grid_thw = ((t_grid, h_grid, w_grid), )

        # One row per patch; each row is a flattened C*T*P*P volume.
        num_patches = t_grid * h_grid * w_grid
        patch_dim = (vc.in_channels * vc.temporal_patch_size *
                     vc.patch_size * vc.patch_size)
        x = jax.random.normal(rng, (num_patches, patch_dim))

        embeddings = vision_transformer(x, grid_thw)

        # Output length follows the grid reduced by spatial merging.
        merge = vc.spatial_merge_size
        expected_len = t_grid * (h_grid // merge) * (w_grid // merge)
        assert embeddings.shape == (expected_len, vc.out_hidden_size)
|
|
400
|
+
|
|
401
|
+
|
|
402
|
+
class TestQwen2_5_VLForConditionalGeneration:
|
|
403
|
+
|
|
404
|
+
@pytest.fixture
def model(self, mock_vllm_config: MockVllmConfig, rng: PRNGKey,
          mesh: Mesh):
    """Yield a Qwen2_5_VLForConditionalGeneration whose vision tower and
    language model are autospec'd mocks."""
    model_cfg = mock_vllm_config.model_config
    vision_cfg = model_cfg.hf_config.vision_config
    vision_target = 'tpu_inference.models.jax.qwen2_5_vl.Qwen2_5_VisionTransformer'
    lm_target = 'tpu_inference.models.jax.qwen2_5_vl.Qwen2ForCausalLM'
    with patch(vision_target, autospec=True) as MockVision:
        with patch(lm_target, autospec=True) as MockLM:
            visual = MockVision.return_value
            visual.dtype = model_cfg.dtype
            visual.config = vision_cfg
            visual.spatial_merge_size = vision_cfg.spatial_merge_size

            vl_model = Qwen2_5_VLForConditionalGeneration(
                mock_vllm_config, rng, mesh)
            # Swap the mocked submodules onto the instance directly.
            vl_model.visual = visual
            vl_model.language_model = MockLM.return_value
            yield vl_model
|
|
420
|
+
|
|
421
|
+
def test_validate_and_reshape_mm_tensor(
        self, model: Qwen2_5_VLForConditionalGeneration):
    """Check that ``_validate_and_reshape_mm_tensor`` yields 2-D arrays.

    Expected shapes:
    - a list of (2, 4) and (3, 4) chunks gives (5, 4) as a ``jax.Array``;
    - a (5, 4) array keeps its shape;
    - a (2, 5, 4) array gives (10, 4).
    A non-array input (a plain string) raises ``ValueError``.
    """
    from_list = model._validate_and_reshape_mm_tensor(
        [np.ones((2, 4)), np.ones((3, 4))], "test_list")
    assert from_list.shape == (5, 4)
    assert isinstance(from_list, jax.Array)

    cases = (
        (np.ones((5, 4)), "test_2d", (5, 4)),
        (np.ones((2, 5, 4)), "test_3d", (10, 4)),
    )
    for arr, label, want_shape in cases:
        got = model._validate_and_reshape_mm_tensor(arr, label)
        assert got.shape == want_shape

    with pytest.raises(ValueError, match="Incorrect type of test_invalid"):
        model._validate_and_reshape_mm_tensor("invalid", "test_invalid")
|
|
439
|
+
|
|
440
|
+
def test_parse_and_validate_image_input(
        self, model: Qwen2_5_VLForConditionalGeneration):
    """Check that image input is parsed only when ``pixel_values`` is given.

    With pixel values, the result is a dict tagged ``"pixel_values"`` that
    carries the pixels and the grid unchanged. Without them, the result is
    ``None``.
    """
    image_grid = ((2, 28, 28), )
    vcfg = model.config.vision_config
    flat_patch = (vcfg.in_channels * vcfg.temporal_patch_size *
                  vcfg.patch_size * vcfg.patch_size)
    # NOTE(review): 4 rows is far fewer than the 2*28*28 patches the grid
    # implies. The parser apparently does not cross-check the two; confirm
    # that is intended.
    pixels = np.ones((4, flat_patch))

    result = model._parse_and_validate_image_input(image_grid,
                                                   pixel_values=pixels)
    assert result is not None
    assert result['type'] == "pixel_values"
    assert result['pixel_values'].shape == (4, flat_patch)
    assert result['image_grid_thw'] == image_grid

    assert model._parse_and_validate_image_input(image_grid) is None
|
|
456
|
+
|
|
457
|
+
def test_parse_and_validate_multimodal_inputs(
        self, model: Qwen2_5_VLForConditionalGeneration):
    """Check that pixel values produce an ``"image"`` entry.

    Supplying pixel values yields an ``"image"`` entry of type
    ``"pixel_values"``. Calling without any multimodal kwargs yields a
    falsy (empty) result.
    """
    image_grid = ((2, 28, 28), )
    vcfg = model.config.vision_config
    flat_patch = (vcfg.in_channels * vcfg.temporal_patch_size *
                  vcfg.patch_size * vcfg.patch_size)
    pixels = np.ones((4, flat_patch))

    parsed = model._parse_and_validate_multimodal_inputs(
        image_grid, pixel_values=pixels)
    assert "image" in parsed
    assert parsed["image"]['type'] == "pixel_values"

    assert not model._parse_and_validate_multimodal_inputs(image_grid)
|
|
471
|
+
|
|
472
|
+
def test_process_image_input_pixels(
        self, model: Qwen2_5_VLForConditionalGeneration):
    """Check that each image in the batch goes through the vision tower once.

    The mocked ``model.visual`` returns a fixed per-image embedding, so two
    images should yield a 2-tuple of such embeddings and exactly two
    vision-tower calls.
    """
    grid_thw = ((2, 28, 28), (2, 28, 28))
    vcfg = model.config.vision_config
    flat_patch = (vcfg.in_channels * vcfg.temporal_patch_size *
                  vcfg.patch_size * vcfg.patch_size)
    # NOTE(review): 8 rows (4 per image) does not match the grid sizes. This
    # only works because model.visual is mocked; confirm before un-mocking.
    pixels = jnp.ones((8, flat_patch))
    image_input = Qwen2_5_VLImagePixelInputs(type="pixel_values",
                                             pixel_values=pixels,
                                             image_grid_thw=grid_thw)

    per_image_tokens = (2 * 28 * 28) // (vcfg.spatial_merge_size**2)
    model.visual.return_value = jnp.ones(
        (per_image_tokens, vcfg.out_hidden_size))

    embeddings = model._process_image_input(image_input)
    assert isinstance(embeddings, tuple)
    assert len(embeddings) == 2
    for emb in embeddings:
        assert emb.shape == (per_image_tokens, vcfg.out_hidden_size)
    assert model.visual.call_count == 2
|
|
493
|
+
|
|
494
|
+
def test_embed_multimodal(self, model: Qwen2_5_VLForConditionalGeneration):
    """Check how ``embed_multimodal`` handles inputs with and without pixels.

    With pixel values, the method calls ``_process_image_input`` (stubbed
    here) once and returns its tuple of embeddings. Without pixel values,
    the result is empty.
    """
    image_grid = ((2, 28, 28), )
    vcfg = model.config.vision_config
    flat_patch = (vcfg.in_channels * vcfg.temporal_patch_size *
                  vcfg.patch_size * vcfg.patch_size)
    pixels = np.ones((4, flat_patch))
    per_image_tokens = (2 * 28 * 28) // (vcfg.spatial_merge_size**2)
    fake_output = jnp.ones((per_image_tokens, vcfg.out_hidden_size))

    with patch.object(model,
                      '_process_image_input',
                      return_value=(fake_output, )) as process_stub:
        result = model.embed_multimodal(image_grid, pixel_values=pixels)
        process_stub.assert_called_once()
        assert isinstance(result, tuple)
        assert len(result) == 1
        assert result[0].shape == (per_image_tokens, vcfg.out_hidden_size)

        # Without pixel values there is nothing to embed.
        assert len(model.embed_multimodal(image_grid)) == 0
|
|
514
|
+
|
|
515
|
+
@patch('tpu_inference.models.jax.qwen2_5_vl.merge_multimodal_embeddings')
def test_embed_input_ids(self, mock_merge_embeddings: MagicMock,
                         model: Qwen2_5_VLForConditionalGeneration,
                         rng: PRNGKey):
    """Check when ``embed_input_ids`` merges multimodal embeddings.

    When the multimodal argument is ``None`` or has zero rows, the text
    embeddings are returned unchanged and no merge happens. Otherwise
    ``merge_multimodal_embeddings`` is called once with the ids, text
    embeddings, multimodal embeddings and the image/video placeholder
    token ids, and its result is returned.
    """
    hidden = model.config.hidden_size
    input_ids = jax.random.randint(rng, (1, 10), 0,
                                   model.config.vocab_size)
    text_embeds = jnp.ones((1, 10, hidden))
    model.language_model.model = MagicMock()
    model.language_model.model.embed = MagicMock(return_value=text_embeds)

    for no_mm in (None, jnp.ones((0, hidden))):
        out = model.embed_input_ids(input_ids, no_mm)
        np.testing.assert_array_equal(out, text_embeds)
        mock_merge_embeddings.assert_not_called()

    mm_embeds = jnp.ones((5, hidden))
    merged = jnp.ones((1, 15, hidden))
    mock_merge_embeddings.return_value = merged

    out = model.embed_input_ids(input_ids, mm_embeds)
    np.testing.assert_array_equal(out, merged)
    mock_merge_embeddings.assert_called_once_with(
        input_ids, text_embeds, mm_embeds,
        [model.config.image_token_id, model.config.video_token_id])
|
|
544
|
+
|
|
545
|
+
def test_call(self, model: Qwen2_5_VLForConditionalGeneration,
              rng: PRNGKey):
    """Check that a text-only call is forwarded to the language model.

    When called with only caches, ids and metadata, the wrapper should call
    ``language_model`` once with ``inputs_embeds=None``. It should then hand
    back the (kv caches, hidden states, aux hidden states) triple it gets.
    """
    hidden = model.config.hidden_size
    caches = [MagicMock()]
    ids = jax.random.randint(rng, (1, 10), 0, model.config.vocab_size)
    metadata = MagicMock(spec=AttentionMetadata)
    model.language_model.return_value = (
        [MagicMock()], jnp.ones((1, 10, hidden)), [])

    new_kvs, x, aux_hidden_states = model(caches, ids, metadata)

    model.language_model.assert_called_once_with(
        kv_caches=caches,
        input_ids=ids,
        attention_metadata=metadata,
        inputs_embeds=None)
    assert len(new_kvs) == 1
    assert x.shape == (1, 10, hidden)
    assert len(aux_hidden_states) == 0
|
|
564
|
+
|
|
565
|
+
def test_compute_logits(self, model: Qwen2_5_VLForConditionalGeneration,
                        rng: PRNGKey):
    """Check that ``compute_logits`` passes straight through to the LM.

    The hidden states should be passed unchanged to
    ``language_model.compute_logits``, and that method's result returned.
    ``rng`` is requested as a fixture but not used.
    """
    states = jnp.ones((1, 10, model.config.hidden_size))
    expected = jnp.ones((1, 10, model.config.vocab_size))
    lm_logits = model.language_model.compute_logits
    lm_logits.return_value = expected

    np.testing.assert_array_equal(model.compute_logits(states), expected)
    lm_logits.assert_called_once_with(states)
|
|
575
|
+
|
|
576
|
+
@patch('tpu_inference.models.jax.qwen2_5_vl.load_hf_weights')
def test_load_weights(self, mock_load_weights: MagicMock,
                      model: Qwen2_5_VLForConditionalGeneration,
                      mock_vllm_config: MockVllmConfig, rng: PRNGKey,
                      mesh: Mesh):
    """Check the weight-loading call for an untied model.

    ``load_weights`` should call ``load_hf_weights`` once, passing the
    config, the model itself, a metadata map that includes both
    ``model.embed_tokens`` and ``lm_head``, and the mesh. Afterwards the
    model's ``rng`` should be an ``nnx.Rngs`` shared with the language model.
    """
    model.load_weights(rng)

    mock_load_weights.assert_called_once()
    call_kwargs = mock_load_weights.call_args.kwargs
    assert call_kwargs['vllm_config'] == mock_vllm_config
    assert call_kwargs['model'] is model

    name_map = call_kwargs['metadata_map'].name_map
    assert "model.embed_tokens" in name_map
    # The default mock config does not tie embeddings, so lm_head must be mapped.
    assert "lm_head" in name_map
    assert call_kwargs['mesh'] is mesh

    assert isinstance(model.rng, nnx.Rngs)
    assert model.language_model.rng is model.rng
|
|
592
|
+
|
|
593
|
+
@patch('tpu_inference.models.jax.qwen2_5_vl.load_hf_weights')
|
|
594
|
+
def test_load_weights_tied(self, mock_load_weights: MagicMock,
|
|
595
|
+
rng: PRNGKey, mesh: Mesh):
|
|
596
|
+
mock_vllm_config_tied = MockVllmConfig(tie_word_embeddings=True)
|
|
597
|
+
with patch('tpu_inference.models.jax.qwen2_5_vl.Qwen2_5_VisionTransformer', autospec=True), \
|
|
598
|
+
patch('tpu_inference.models.jax.qwen2_5_vl.Qwen2ForCausalLM', autospec=True):
|
|
599
|
+
model = Qwen2_5_VLForConditionalGeneration(mock_vllm_config_tied,
|
|
600
|
+
rng, mesh)
|
|
601
|
+
|
|
602
|
+
model.load_weights(rng)
|
|
603
|
+
mock_load_weights.assert_called_once()
|
|
604
|
+
kwargs = mock_load_weights.call_args.kwargs
|
|
605
|
+
assert "lm_head" not in kwargs['metadata_map'].name_map
|