tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +317 -34
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +26 -6
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +110 -12
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +2 -45
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +15 -10
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +92 -8
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +25 -4
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +25 -12
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +32 -9
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +101 -494
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
- tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +112 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +18 -5
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +179 -51
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
- tpu_inference/models/jax/utils/weight_utils.py +234 -155
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +51 -72
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +180 -80
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +55 -33
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +54 -2
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +16 -3
- tpu_inference/runner/tpu_runner.py +124 -61
- tpu_inference/runner/utils.py +2 -2
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +84 -22
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +72 -44
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +66 -52
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -186
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tests/runner/test_multimodal_manager.py (new file)
@@ -0,0 +1,429 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock, patch
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
+                         SchedulerConfig, SpeculativeConfig, VllmConfig)
+from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
+from vllm.multimodal.inputs import (MultiModalBatchedField,
+                                    MultiModalFeatureSpec, MultiModalFieldElem,
+                                    MultiModalKwargsItem, PlaceholderRange)
+from vllm.sampling_params import SamplingType
+from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
+
+from tpu_inference.runner.input_batch import CachedRequestState
+from tpu_inference.runner.tpu_runner import TPUModelRunner
+
+
+class TestMultiModalManager:
+
+    def setup_method(self):
+        # Mock JAX dependencies
+        self.mock_devices = [MagicMock(coords=i) for i in range(1)]
+        device_array = np.array(jax.devices()[:1]).reshape(1, 1, 1, 1)
+        self.mock_mesh = jax.make_mesh(device_array.shape,
+                                       ('data', 'attn_dp', 'expert', 'model'))
+        self.mock_rng_key = MagicMock()
+
+        with patch('jax.devices', return_value=self.mock_devices), \
+             patch('jax.make_mesh', return_value=self.mock_mesh), \
+             patch('jax.random.key', return_value=self.mock_rng_key), \
+             patch('tpu_inference.runner.tpu_runner.get_model', return_value=MagicMock()), \
+             patch('tpu_inference.runner.tpu_runner.make_optimized_mesh', return_value=self.mock_mesh):
+
+            model_config = ModelConfig(tokenizer_mode="auto",
+                                       trust_remote_code=False,
+                                       seed=0,
+                                       dtype='bfloat16')
+            cache_config = CacheConfig(
+                block_size=16,
+                gpu_memory_utilization=0.9,
+                swap_space=4,
+                cache_dtype="auto",
+            )
+            scheduler_config = SchedulerConfig(max_num_seqs=16,
+                                               max_model_len=1024,
+                                               is_encoder_decoder=False)
+            parallel_config = ParallelConfig(
+                pipeline_parallel_size=1,
+                tensor_parallel_size=1,
+                worker_use_ray=False,
+            )
+            speculative_config = SpeculativeConfig(
+                model='ngram',
+                num_speculative_tokens=5,
+                prompt_lookup_max=4,
+            )
+            vllm_config = VllmConfig(
+                model_config=model_config,
+                cache_config=cache_config,
+                scheduler_config=scheduler_config,
+                parallel_config=parallel_config,
+                speculative_config=speculative_config,
+                observability_config={},
+                additional_config={},
+            )
+
+            self.runner = TPUModelRunner(vllm_config,
+                                         devices=self.mock_devices)
+
+    def test_execute_mm_encoder_single_image(self):
+        import torch
+        """Tests _execute_mm_encoder with a single request and a single image."""
+        # 1. ===== Setup =====
+        self.runner.is_multimodal_model = True
+        self.mock_get_mm_embed_fn = MagicMock()
+        self.runner.embed_multimodal_fn = self.mock_get_mm_embed_fn
+
+        self.runner.state = MagicMock()
+        # Mock scheduler output
+        mock_scheduler_output = MagicMock(spec=VllmSchedulerOutput)
+        mock_scheduler_output.scheduled_encoder_inputs = {"req-1": [0]}
+
+        # Mock request state
+        dummy_pixel_values = torch.randn(3, 224, 224, dtype=torch.bfloat16)
+        dummy_grid_thw = torch.tensor([[1, 1, 1]], dtype=torch.int64)
+        mm_item = MultiModalKwargsItem.from_elems([
+            MultiModalFieldElem("image", "pixel_values", dummy_pixel_values,
+                                MultiModalBatchedField()),
+            MultiModalFieldElem("image", "image_grid_thw", dummy_grid_thw,
+                                MultiModalBatchedField())
+        ])
+
+        req_state = CachedRequestState(
+            req_id="req-1",
+            prompt_token_ids=[1, 2, 3],
+            output_token_ids=[],
+            sampling_params=MagicMock(),
+            block_ids=(),
+            num_computed_tokens=0,
+            mm_features=[
+                MultiModalFeatureSpec(data=mm_item,
+                                      identifier="req-1",
+                                      modality="image",
+                                      mm_position=PlaceholderRange(offset=0,
+                                                                   length=1))
+            ],
+            lora_request=None,
+            pooling_params=None,
+            generator=None,
+        )
+        self.runner.requests = {"req-1": req_state}
+
+        # Mock the return value of the multimodal encoder
+        dummy_embedding = jnp.ones((10, 128), dtype=jnp.bfloat16)
+        self.mock_get_mm_embed_fn.return_value = (dummy_embedding, )
+
+        # 2. ===== Act =====
+        self.runner.mm_manager.execute_mm_encoder(mock_scheduler_output)
+
+        # 3. ===== Assert =====
+        # Check if encoder_cache is populated correctly
+        assert "req-1" in self.runner.encoder_cache
+        cached_embedding = self.runner.encoder_cache["req-1"]
+        np.testing.assert_array_equal(np.asarray(cached_embedding),
+                                      np.asarray(dummy_embedding))
+
+        # Check if embed_multimodal_fn was called with correct args
+        self.mock_get_mm_embed_fn.assert_called_once()
+        call_args = self.mock_get_mm_embed_fn.call_args
+
+        # Positional args: (state, image_grid_thw)
+        state_arg, grid_arg = call_args.args
+        # Keyword args: **batched_mm_inputs
+        kwargs_arg = call_args.kwargs
+
+        assert state_arg == self.runner.state
+        assert grid_arg == ((1, 1, 1), )
+        assert "pixel_values" in kwargs_arg
+
+        # Verify the pixel values tensor passed to the mock
+        passed_pixel_values = kwargs_arg['pixel_values']
+        assert isinstance(passed_pixel_values, np.ndarray)
+        assert passed_pixel_values.dtype == jnp.bfloat16
+
+        # Convert torch tensor for comparison
+        expected_pixel_values = dummy_pixel_values.unsqueeze(0).to(
+            torch.float32).numpy().astype(jnp.bfloat16)
+        np.testing.assert_array_equal(np.asarray(passed_pixel_values),
+                                      expected_pixel_values)
+
+    def test_execute_mm_encoder_multiple_images(self):
+        import torch
+        """Tests _execute_mm_encoder with multiple requests and images."""
+        # 1. ===== Setup =====
+        self.runner.is_multimodal_model = True
+        self.mock_get_mm_embed_fn = MagicMock()
+        self.runner.embed_multimodal_fn = self.mock_get_mm_embed_fn
+
+        self.runner.state = MagicMock()
+        # Mock scheduler output for two requests
+        mock_scheduler_output = MagicMock(spec=VllmSchedulerOutput)
+        mock_scheduler_output.scheduled_encoder_inputs = {
+            "req-1": [0],
+            "req-2": [0]
+        }
+
+        # Mock request states
+        px_1 = torch.randn(3, 224, 224, dtype=torch.bfloat16)
+        grid_1 = torch.tensor([[1, 1, 1]], dtype=torch.int64)
+
+        mm_item_1 = MultiModalKwargsItem.from_elems([
+            MultiModalFieldElem("image", "pixel_values", px_1,
+                                MultiModalBatchedField()),
+            MultiModalFieldElem("image", "image_grid_thw", grid_1,
+                                MultiModalBatchedField())
+        ])
+
+        req_state_1 = CachedRequestState(
+            req_id="req-1",
+            prompt_token_ids=[],
+            output_token_ids=[],
+            sampling_params=MagicMock(),
+            block_ids=(),
+            num_computed_tokens=0,
+            mm_features=[
+                MultiModalFeatureSpec(data=mm_item_1,
+                                      identifier="req-1",
+                                      modality="image",
+                                      mm_position=PlaceholderRange(offset=0,
+                                                                   length=1))
+            ],
+            lora_request=None,
+            pooling_params=None,
+            generator=None)
+
+        px_2 = torch.randn(3, 224, 224, dtype=torch.bfloat16)
+        grid_2 = torch.tensor([[1, 2, 2]], dtype=torch.int64)
+        mm_item_2 = MultiModalKwargsItem.from_elems([
+            MultiModalFieldElem("image", "pixel_values", px_2,
+                                MultiModalBatchedField()),
+            MultiModalFieldElem("image", "image_grid_thw", grid_2,
+                                MultiModalBatchedField())
+        ])
+
+        req_state_2 = CachedRequestState(
+            req_id="req-2",
+            prompt_token_ids=[],
+            output_token_ids=[],
+            sampling_params=MagicMock(),
+            block_ids=(),
+            num_computed_tokens=0,
+            mm_features=[
+                MultiModalFeatureSpec(data=mm_item_2,
+                                      identifier="req-2",
+                                      modality="image",
+                                      mm_position=PlaceholderRange(offset=0,
+                                                                   length=1))
+            ],
+            lora_request=None,
+            pooling_params=None,
+            generator=None)
+
+        self.runner.requests = {"req-1": req_state_1, "req-2": req_state_2}
+
+        emb_1 = jnp.ones((10, 128), dtype=jnp.bfloat16)
+        emb_2 = jnp.ones((20, 128), dtype=jnp.bfloat16) * 2
+        self.mock_get_mm_embed_fn.return_value = (emb_1, emb_2)
+
+        # 2. ===== Act =====
+        self.runner.mm_manager.execute_mm_encoder(mock_scheduler_output)
+
+        # 3. ===== Assert =====
+        assert "req-1" in self.runner.encoder_cache
+        np.testing.assert_array_equal(
+            np.asarray(self.runner.encoder_cache["req-1"]), np.asarray(emb_1))
+        assert "req-2" in self.runner.encoder_cache
+        np.testing.assert_array_equal(
+            np.asarray(self.runner.encoder_cache["req-2"]), np.asarray(emb_2))
+
+        self.mock_get_mm_embed_fn.assert_called_once()
+        call_args = self.mock_get_mm_embed_fn.call_args
+
+        state_arg, grid_arg = call_args.args
+        kwargs_arg = call_args.kwargs
+
+        assert state_arg == self.runner.state
+        assert grid_arg == ((1, 1, 1), (1, 2, 2))
+        assert "pixel_values" in kwargs_arg
+
+        passed_pixel_values = kwargs_arg['pixel_values']
+        assert passed_pixel_values.shape == (2, 3, 224, 224)
+
+        expected_pixel_values = torch.stack([px_1, px_2], dim=0).to(
+            torch.float32).numpy().astype(jnp.bfloat16)
+        np.testing.assert_array_equal(np.asarray(passed_pixel_values),
+                                      expected_pixel_values)
+
+    def test_gather_mm_embeddings_chunked_prefill(self):
+        """Tests _gather_mm_embeddings with chunked prefill scenarios."""
+        # 1. ===== Setup =====
+        self.runner.is_multimodal_model = True
+        req_id = "req-1"
+
+        # Mock encoder output
+        encoder_embedding = jnp.arange(56 * 128, dtype=jnp.bfloat16).reshape(
+            (56, 128))
+        self.runner.encoder_cache = {req_id: encoder_embedding}
+
+        mock_sampling_params = MagicMock()
+        mock_sampling_params.sampling_type = SamplingType.GREEDY
+        mock_sampling_params.top_k = -1
+        mock_sampling_params.top_p = 1.0
+        mock_sampling_params.temperature = 0.0
+        mock_sampling_params.min_tokens = 0
+        mock_sampling_params.logprobs = None
+        mock_sampling_params.logit_bias = None
+        mock_sampling_params.allowed_token_ids = set()
+        mock_sampling_params.bad_words_token_ids = None
+        mock_sampling_params.all_stop_token_ids = set()
+
+        # Mock request state
+        req_state = CachedRequestState(
+            req_id=req_id,
+            prompt_token_ids=list(range(100)),
+            output_token_ids=[],
+            sampling_params=mock_sampling_params,
+            block_ids=([], ),
+            num_computed_tokens=0,  # This will be updated per step
+            mm_features=[
+                MultiModalFeatureSpec(data=None,
+                                      identifier=req_id,
+                                      modality="image",
+                                      mm_position=PlaceholderRange(offset=10,
+                                                                   length=56))
+            ],
+            lora_request=None,
+            pooling_params=None,
+            generator=None,
+        )
+        self.runner.requests = {req_id: req_state}
+        self.runner.input_batch.add_request(req_state)
+
+        # 2. ===== Act & Assert =====
+
+        # ----- Step 1: First chunk of prefill -----
+        req_state.num_computed_tokens = 0
+        mock_scheduler_output_1 = MagicMock(spec=VllmSchedulerOutput)
+        mock_scheduler_output_1.num_scheduled_tokens = {req_id: 20}
+
+        gathered_embeds_1 = self.runner.mm_manager.gather_mm_embeddings(
+            mock_scheduler_output_1, target_pad_len=10)
+
+        expected_embeds_1 = encoder_embedding[0:10]
+        assert gathered_embeds_1.shape == expected_embeds_1.shape
+        np.testing.assert_array_equal(np.asarray(gathered_embeds_1),
+                                      np.asarray(expected_embeds_1))
+
+        # ----- Step 2: Middle chunk of prefill -----
+        req_state.num_computed_tokens = 20
+        mock_scheduler_output_2 = MagicMock(spec=VllmSchedulerOutput)
+        mock_scheduler_output_2.num_scheduled_tokens = {req_id: 30}
+
+        gathered_embeds_2 = self.runner.mm_manager.gather_mm_embeddings(
+            mock_scheduler_output_2, target_pad_len=30)
+
+        expected_embeds_2 = encoder_embedding[10:40]
+        assert gathered_embeds_2.shape == expected_embeds_2.shape
+        np.testing.assert_array_equal(np.asarray(gathered_embeds_2),
+                                      np.asarray(expected_embeds_2))
+
+        # ----- Step 3: Last chunk of prefill -----
+        req_state.num_computed_tokens = 50
+        mock_scheduler_output_3 = MagicMock(spec=VllmSchedulerOutput)
+        mock_scheduler_output_3.num_scheduled_tokens = {req_id: 30}
+
+        gathered_embeds_3 = self.runner.mm_manager.gather_mm_embeddings(
+            mock_scheduler_output_3, target_pad_len=16)
+
+        expected_embeds_3 = encoder_embedding[40:56]
+        assert gathered_embeds_3.shape == expected_embeds_3.shape
+        np.testing.assert_array_equal(np.asarray(gathered_embeds_3),
+                                      np.asarray(expected_embeds_3))
+
+    def test_calc_mrope_positions(self):
+        """Tests the calculation of M-RoPE positions for mixed prompt/completion."""
+        # 1. ===== Setup =====
+        self.runner.uses_mrope = True
+        req_id = "req-1"
+        prompt_len = 20
+        num_computed = 15
+        num_scheduled = 10
+        mrope_delta = 100
+
+        # Mock request state with pre-computed mrope positions for the prompt
+        mock_mrope_positions = np.arange(3 * prompt_len,
+                                         dtype=np.int64).reshape(
+                                             3, prompt_len)
+        mock_sampling_params = MagicMock()
+        mock_sampling_params.sampling_type = SamplingType.GREEDY
+        mock_sampling_params.top_k = -1
+        mock_sampling_params.top_p = 1.0
+        mock_sampling_params.temperature = 0.0
+        mock_sampling_params.min_tokens = 0
+        mock_sampling_params.logprobs = None
+        mock_sampling_params.logit_bias = None
+        mock_sampling_params.allowed_token_ids = set()
+        mock_sampling_params.bad_words_token_ids = None
+        mock_sampling_params.all_stop_token_ids = set()
+
+        req_state = CachedRequestState(
+            req_id=req_id,
+            prompt_token_ids=list(range(prompt_len)),
+            output_token_ids=[],
+            sampling_params=mock_sampling_params,
+            block_ids=([], ),
+            num_computed_tokens=num_computed,
+            mm_features=[],
+            lora_request=None,
+            pooling_params=None,
+            generator=None,
+            mrope_positions=mock_mrope_positions,
+            mrope_position_delta=mrope_delta,
+        )
+        self.runner.requests = {req_id: req_state}
+        self.runner.input_batch.add_request(req_state)
+        # Manually set num_computed_tokens in the batch as add_request sets it to 0
+        self.runner.input_batch.num_computed_tokens_cpu[0] = num_computed
+
+        # Mock scheduler output
+        mock_scheduler_output = MagicMock(spec=VllmSchedulerOutput)
+        mock_scheduler_output.num_scheduled_tokens = {req_id: num_scheduled}
+
+        # Patch the static method that computes completion positions
+        with patch.object(MRotaryEmbedding,
+                          "get_next_input_positions_tensor") as mock_get_next:
+            # 2. ===== Act =====
+            self.runner.mm_manager.calc_mrope_positions(mock_scheduler_output)
+
+            # 3. ===== Assert =====
+            # The first 5 positions should be copied from the pre-computed prompt positions
+            expected_prompt_part = mock_mrope_positions[:, 15:20]
+            actual_prompt_part = self.runner.mrope_positions_cpu[:, 0:5]
+            np.testing.assert_array_equal(actual_prompt_part,
+                                          expected_prompt_part)
+
+            # The next 5 positions should be computed on-the-fly
+            mock_get_next.assert_called_once()
+            call_kwargs = mock_get_next.call_args.kwargs
+            np.testing.assert_array_equal(call_kwargs["out"],
+                                          self.runner.mrope_positions_cpu)
+            assert call_kwargs["out_offset"] == 5
+            assert call_kwargs["mrope_position_delta"] == mrope_delta
+            assert call_kwargs["context_len"] == prompt_len
+            assert call_kwargs["num_new_tokens"] == 5
tests/runner/test_persistent_batch_manager.py (new file)
@@ -0,0 +1,84 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from unittest.mock import MagicMock
+
+import numpy as np
+
+from tpu_inference.runner.persistent_batch_manager import \
+    PersistentBatchManager
+
+
+class TestPersistentBatchManager(unittest.TestCase):
+
+    def test_update_states_pp_non_last_rank(self):
+        """Tests update_states with pipeline parallelism when
+        the current rank is not the last rank.
+
+        This test verifies that when new tokens are received from the scheduler,
+        the internal state of the PersistentBatchManager (including request
+        states and the input batch) is correctly updated.
+        """
+
+        req_id = 101
+        initial_output_tokens = [10, 20]
+
+        req_state = MagicMock()
+        req_state.num_tokens = 2
+        req_state.output_token_ids = list(initial_output_tokens)
+
+        requests = {req_id: req_state}
+
+        input_batch = MagicMock()
+        input_batch.req_id_to_index = {req_id: 0}
+        input_batch.num_prompt_tokens = np.array([2], dtype=np.int32)
+        input_batch.token_ids_cpu = np.zeros((1, 10), dtype=np.int32)
+        input_batch.num_tokens = np.array([2], dtype=np.int32)
+        input_batch.num_tokens_no_spec = np.array([2], dtype=np.int32)
+        input_batch.num_reqs = 1
+
+        encoder_cache = MagicMock()
+        model_config = MagicMock()
+
+        manager = PersistentBatchManager(requests,
+                                         input_batch,
+                                         encoder_cache,
+                                         False,
+                                         model_config,
+                                         is_last_rank=False)
+
+        scheduler_output = MagicMock()
+        req_data = MagicMock()
+        req_data.req_ids = [req_id]
+        req_data.num_computed_tokens = [2]
+        new_token_id = [30]
+        req_data.new_token_ids = [new_token_id]
+        req_data.new_block_ids = [None]
+        req_data.resumed_from_preemption = [False]
+        req_data.num_output_tokens = [len(initial_output_tokens) + 1]
+        scheduler_output.scheduled_cached_reqs = req_data
+        scheduler_output.scheduled_spec_decode_tokens = {}
+
+        manager.update_states(scheduler_output, None)
+
+        expected_output_token_ids = initial_output_tokens + new_token_id
+        self.assertEqual(req_state.output_token_ids, expected_output_token_ids)
+
+        np.testing.assert_array_equal(
+            manager.input_batch.token_ids_cpu[0, 2:3],
+            np.array(new_token_id, dtype=np.int32))
+
+        self.assertEqual(manager.input_batch.num_tokens[0], 3)
+        self.assertEqual(manager.input_batch.num_tokens_no_spec[0], 3)