tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +14 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +25 -8
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +14 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +20 -3
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +20 -26
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +22 -3
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +100 -455
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
- tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +37 -16
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +113 -124
- tpu_inference/models/jax/gpt_oss.py +23 -7
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
- tpu_inference/models/jax/utils/weight_utils.py +32 -1
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +27 -29
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +69 -35
- tpu_inference/runner/kv_cache.py +14 -0
- tpu_inference/runner/kv_cache_manager.py +15 -2
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +30 -10
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +31 -30
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +23 -7
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -208
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1099 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from contextlib import nullcontext
|
|
16
|
+
from unittest.mock import MagicMock, patch
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
import pytest
|
|
20
|
+
|
|
21
|
+
from tpu_inference.runner.tpu_runner import TPUModelRunner
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class TestTPUJaxRunnerDPInputsLightweight:
|
|
25
|
+
|
|
26
|
+
def setup_method(self):
|
|
27
|
+
self.runner = MagicMock()
|
|
28
|
+
|
|
29
|
+
# Basic DP configuration
|
|
30
|
+
self.runner.dp_size = 2
|
|
31
|
+
self.runner.max_num_tokens = 64
|
|
32
|
+
self.runner.max_num_reqs = 8
|
|
33
|
+
self.runner.max_num_blocks_per_req = 8
|
|
34
|
+
self.runner.num_tokens_paddings = [16, 32, 64]
|
|
35
|
+
|
|
36
|
+
# Mock input batch - adjust num_reqs to match test data
|
|
37
|
+
self.runner.input_batch = MagicMock()
|
|
38
|
+
self.runner.input_batch.num_reqs = 2
|
|
39
|
+
self.runner.input_batch.req_ids = ["req1", "req2", "req3", "req4"]
|
|
40
|
+
self.runner.input_batch.req_id_to_index = {
|
|
41
|
+
"req1": 0,
|
|
42
|
+
"req2": 1,
|
|
43
|
+
"req3": 2,
|
|
44
|
+
"req4": 3
|
|
45
|
+
}
|
|
46
|
+
self.runner.input_batch.num_computed_tokens_cpu = np.array(
|
|
47
|
+
[10, 20, 5, 15])
|
|
48
|
+
self.runner.input_batch.token_ids_cpu = np.random.randint(
|
|
49
|
+
0, 1000, (8, 64), dtype=np.int32)
|
|
50
|
+
|
|
51
|
+
# Mock block table
|
|
52
|
+
mock_block_table = MagicMock()
|
|
53
|
+
mock_block_table.get_cpu_tensor.return_value = np.arange(32).reshape(
|
|
54
|
+
4, 8)
|
|
55
|
+
self.runner.input_batch.block_table = [mock_block_table]
|
|
56
|
+
|
|
57
|
+
# Initialize CPU arrays that the method modifies
|
|
58
|
+
self.runner.input_ids_cpu = np.zeros(64, dtype=np.int32)
|
|
59
|
+
self.runner.positions_cpu = np.zeros(64, dtype=np.int32)
|
|
60
|
+
self.runner.query_start_loc_cpu = np.zeros(10, dtype=np.int32)
|
|
61
|
+
self.runner.seq_lens_cpu = np.zeros(8, dtype=np.int32)
|
|
62
|
+
self.runner.logits_indices_cpu = np.zeros(8, dtype=np.int32)
|
|
63
|
+
self.runner.block_tables_cpu = [np.zeros((8, 8), dtype=np.int32)]
|
|
64
|
+
self.runner.arange_cpu = np.arange(64, dtype=np.int64)
|
|
65
|
+
|
|
66
|
+
# mock kv cache group
|
|
67
|
+
mock_kv_cache_config = MagicMock()
|
|
68
|
+
mock_kv_cache_group = MagicMock()
|
|
69
|
+
mock_kv_cache_config.kv_cache_groups = [mock_kv_cache_group]
|
|
70
|
+
self.runner.kv_cache_config = mock_kv_cache_config
|
|
71
|
+
self.runner.use_hybrid_kvcache = False
|
|
72
|
+
|
|
73
|
+
# Mock scheduler config for async scheduling
|
|
74
|
+
self.runner.scheduler_config = MagicMock()
|
|
75
|
+
self.runner.scheduler_config.async_scheduling = False # Default to False for most tests
|
|
76
|
+
self.runner._pre_async_results = None # Default to None for most tests
|
|
77
|
+
|
|
78
|
+
# Bind the actual methods to our mock
|
|
79
|
+
self.runner._prepare_inputs_dp = TPUModelRunner._prepare_inputs_dp.__get__(
|
|
80
|
+
self.runner)
|
|
81
|
+
self.runner._prepare_dp_input_metadata = TPUModelRunner._prepare_dp_input_metadata.__get__(
|
|
82
|
+
self.runner)
|
|
83
|
+
self.runner._prepare_async_token_substitution_indices_dp = TPUModelRunner._prepare_async_token_substitution_indices_dp.__get__(
|
|
84
|
+
self.runner)
|
|
85
|
+
|
|
86
|
+
def _create_mock_scheduler_output(self,
|
|
87
|
+
num_scheduled_tokens_dict,
|
|
88
|
+
assigned_dp_ranks,
|
|
89
|
+
scheduled_spec_decode_tokens=None):
|
|
90
|
+
"""Create a minimal mock scheduler output."""
|
|
91
|
+
mock_output = MagicMock()
|
|
92
|
+
mock_output.num_scheduled_tokens = num_scheduled_tokens_dict
|
|
93
|
+
mock_output.assigned_dp_rank = assigned_dp_ranks
|
|
94
|
+
mock_output.total_num_scheduled_tokens = sum(
|
|
95
|
+
num_scheduled_tokens_dict.values())
|
|
96
|
+
mock_output.scheduled_spec_decode_tokens = scheduled_spec_decode_tokens or {}
|
|
97
|
+
mock_output.grammar_bitmask = None
|
|
98
|
+
return mock_output
|
|
99
|
+
|
|
100
|
+
def _create_mock_hybrid_kv_cache_config(self):
|
|
101
|
+
mock_kv_cache_config = MagicMock()
|
|
102
|
+
mock_kv_cache_group1 = MagicMock()
|
|
103
|
+
mock_kv_cache_group1.layer_names = [f'layer.{i}' for i in range(10)]
|
|
104
|
+
mock_kv_cache_group2 = MagicMock()
|
|
105
|
+
mock_kv_cache_group2.layer_names = [
|
|
106
|
+
f'layer.{i}' for i in range(10, 20)
|
|
107
|
+
]
|
|
108
|
+
mock_kv_cache_config.kv_cache_groups = [
|
|
109
|
+
mock_kv_cache_group1, mock_kv_cache_group2
|
|
110
|
+
]
|
|
111
|
+
self.runner.kv_cache_config = mock_kv_cache_config
|
|
112
|
+
self.runner.use_hybrid_kvcache = True
|
|
113
|
+
|
|
114
|
+
@patch('tpu_inference.runner.tpu_runner.NamedSharding')
|
|
115
|
+
@patch('tpu_inference.runner.tpu_runner.runner_utils')
|
|
116
|
+
@patch('tpu_inference.runner.tpu_runner.device_array',
|
|
117
|
+
side_effect=lambda mesh, tensors, **kwargs: tensors)
|
|
118
|
+
@patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
|
|
119
|
+
def test_prepare_inputs_dp_basic_functionality(self,
|
|
120
|
+
mock_sampling_metadata,
|
|
121
|
+
mock_device_array,
|
|
122
|
+
mock_runner_utils,
|
|
123
|
+
mock_named_sharding):
|
|
124
|
+
"""Test basic functionality of _prepare_inputs_dp."""
|
|
125
|
+
# Mock utility functions
|
|
126
|
+
mock_runner_utils.get_padded_token_len.return_value = 16
|
|
127
|
+
mock_sampling_metadata.from_input_batch.return_value = MagicMock()
|
|
128
|
+
mock_named_sharding.return_value = MagicMock()
|
|
129
|
+
|
|
130
|
+
# Create test data - only use req1 and req2 to match num_reqs=2
|
|
131
|
+
num_scheduled_tokens = {"req1": 5, "req2": 3}
|
|
132
|
+
assigned_dp_ranks = {"req1": 0, "req2": 1}
|
|
133
|
+
scheduler_output = self._create_mock_scheduler_output(
|
|
134
|
+
num_scheduled_tokens, assigned_dp_ranks)
|
|
135
|
+
|
|
136
|
+
# Execute the method
|
|
137
|
+
result = self.runner._prepare_inputs_dp(scheduler_output)
|
|
138
|
+
|
|
139
|
+
# Basic assertions
|
|
140
|
+
assert len(result) == 8
|
|
141
|
+
input_ids, positions, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
|
|
142
|
+
|
|
143
|
+
# Verify utility functions were called
|
|
144
|
+
mock_runner_utils.get_padded_token_len.assert_called()
|
|
145
|
+
|
|
146
|
+
def test_prepare_inputs_dp_error_conditions(self):
|
|
147
|
+
"""Test error handling in DP input preparation."""
|
|
148
|
+
# Test with zero scheduled tokens - should fail assertion: total_num_scheduled_tokens > 0
|
|
149
|
+
scheduler_output = self._create_mock_scheduler_output({}, {})
|
|
150
|
+
scheduler_output.total_num_scheduled_tokens = 0
|
|
151
|
+
|
|
152
|
+
with pytest.raises(AssertionError):
|
|
153
|
+
self.runner._prepare_inputs_dp(scheduler_output)
|
|
154
|
+
|
|
155
|
+
# Test with zero requests - should fail assertion: num_reqs > 0
|
|
156
|
+
self.runner.input_batch.num_reqs = 0
|
|
157
|
+
scheduler_output = self._create_mock_scheduler_output({"req1": 5},
|
|
158
|
+
{"req1": 0})
|
|
159
|
+
|
|
160
|
+
with pytest.raises(AssertionError):
|
|
161
|
+
self.runner._prepare_inputs_dp(scheduler_output)
|
|
162
|
+
|
|
163
|
+
@patch('tpu_inference.runner.tpu_runner.NamedSharding')
|
|
164
|
+
@patch('tpu_inference.runner.tpu_runner.runner_utils')
|
|
165
|
+
@patch('tpu_inference.runner.tpu_runner.device_array',
|
|
166
|
+
side_effect=lambda mesh, tensors, **kwargs: tensors)
|
|
167
|
+
@patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
|
|
168
|
+
def test_prepare_inputs_dp_hybrid_kvcache(self, mock_sampling_metadata,
|
|
169
|
+
mock_device_array,
|
|
170
|
+
mock_runner_utils,
|
|
171
|
+
mock_named_sharding):
|
|
172
|
+
"""Test basic functionality of _prepare_inputs_dp."""
|
|
173
|
+
# Mock utility functions
|
|
174
|
+
mock_runner_utils.get_padded_token_len.return_value = 16
|
|
175
|
+
mock_sampling_metadata.from_input_batch.return_value = MagicMock()
|
|
176
|
+
mock_named_sharding.return_value = MagicMock()
|
|
177
|
+
|
|
178
|
+
# Create test data - only use req1 and req2 to match num_reqs=2
|
|
179
|
+
num_scheduled_tokens = {"req1": 5, "req2": 3}
|
|
180
|
+
assigned_dp_ranks = {"req1": 0, "req2": 1}
|
|
181
|
+
scheduler_output = self._create_mock_scheduler_output(
|
|
182
|
+
num_scheduled_tokens, assigned_dp_ranks)
|
|
183
|
+
|
|
184
|
+
# Create hybrid kv cache config with 10 full attn layers, 10 sw attn layers
|
|
185
|
+
self._create_mock_hybrid_kv_cache_config()
|
|
186
|
+
|
|
187
|
+
# update input_batch's block_table
|
|
188
|
+
mock_block_table = MagicMock()
|
|
189
|
+
mock_block_table.get_cpu_tensor.return_value = np.arange(32).reshape(
|
|
190
|
+
4, 8)
|
|
191
|
+
self.runner.input_batch.block_table = [
|
|
192
|
+
mock_block_table, mock_block_table
|
|
193
|
+
]
|
|
194
|
+
|
|
195
|
+
# update model runner's block_tables_cpu:
|
|
196
|
+
self.runner.block_tables_cpu = [
|
|
197
|
+
np.zeros((8, 8), dtype=np.int32),
|
|
198
|
+
np.zeros((8, 8), dtype=np.int32)
|
|
199
|
+
]
|
|
200
|
+
|
|
201
|
+
# Execute the method
|
|
202
|
+
result = self.runner._prepare_inputs_dp(scheduler_output)
|
|
203
|
+
|
|
204
|
+
# Basic assertions
|
|
205
|
+
assert len(result) == 8
|
|
206
|
+
input_ids, positions, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
|
|
207
|
+
|
|
208
|
+
# Verify utility functions were called
|
|
209
|
+
mock_runner_utils.get_padded_token_len.assert_called()
|
|
210
|
+
|
|
211
|
+
# Verify there's attention_metadata for each layer
|
|
212
|
+
assert isinstance(attention_metadata, dict)
|
|
213
|
+
assert len(attention_metadata) == 20
|
|
214
|
+
|
|
215
|
+
def test_prepare_dp_input_metadata(self):
|
|
216
|
+
num_scheduled_tokens = {"req1": 10, "req2": 5, "req3": 8, "req4": 3}
|
|
217
|
+
assigned_dp_ranks = {"req1": 0, "req2": 0, "req3": 1, "req4": 1}
|
|
218
|
+
|
|
219
|
+
self.runner.input_batch.num_reqs = 4
|
|
220
|
+
self.runner.input_batch.req_ids = ["req1", "req2", "req3", "req4"]
|
|
221
|
+
self.runner.max_num_reqs = 8
|
|
222
|
+
|
|
223
|
+
scheduler_output = self._create_mock_scheduler_output(
|
|
224
|
+
num_scheduled_tokens, assigned_dp_ranks)
|
|
225
|
+
|
|
226
|
+
with patch('tpu_inference.runner.tpu_runner.runner_utils'
|
|
227
|
+
) as mock_runner_utils:
|
|
228
|
+
mock_runner_utils.get_padded_token_len.side_effect = lambda paddings_list, val: 16 if val <= 15 else 32 # Padded tokens per DP rank
|
|
229
|
+
|
|
230
|
+
result = self.runner._prepare_dp_input_metadata(scheduler_output)
|
|
231
|
+
|
|
232
|
+
(req_ids_dp, req_indices_dp, num_scheduled_tokens_per_dp_rank,
|
|
233
|
+
scheduled_tokens_per_dp_rank, num_req_per_dp_rank,
|
|
234
|
+
padded_num_scheduled_tokens_per_dp_rank, padded_num_reqs,
|
|
235
|
+
padded_total_num_scheduled_tokens, padded_num_reqs_per_dp_rank,
|
|
236
|
+
logits_indices_selector, max_num_reqs_per_dp_rank) = result
|
|
237
|
+
|
|
238
|
+
# 1. req_ids_dp: Dictionary mapping DP rank to request IDs
|
|
239
|
+
assert isinstance(req_ids_dp, dict)
|
|
240
|
+
assert req_ids_dp[0] == ["req1", "req2"]
|
|
241
|
+
assert req_ids_dp[1] == ["req3", "req4"]
|
|
242
|
+
|
|
243
|
+
# 2. req_indices_dp: Dictionary mapping DP rank to request indices
|
|
244
|
+
assert isinstance(req_indices_dp, dict)
|
|
245
|
+
assert req_indices_dp[0] == [0, 1] # indices of req1, req2
|
|
246
|
+
assert req_indices_dp[1] == [2, 3] # indices of req3, req4
|
|
247
|
+
|
|
248
|
+
# 3. num_scheduled_tokens_per_dp_rank: Total tokens per DP rank
|
|
249
|
+
assert isinstance(num_scheduled_tokens_per_dp_rank, dict)
|
|
250
|
+
assert num_scheduled_tokens_per_dp_rank[0] == 15 # 10 + 5
|
|
251
|
+
assert num_scheduled_tokens_per_dp_rank[1] == 11 # 8 + 3
|
|
252
|
+
|
|
253
|
+
# 4. scheduled_tokens_per_dp_rank: List of token counts per request per DP rank
|
|
254
|
+
assert isinstance(scheduled_tokens_per_dp_rank, dict)
|
|
255
|
+
assert scheduled_tokens_per_dp_rank[0] == [10,
|
|
256
|
+
5] # req1=10, req2=5
|
|
257
|
+
assert scheduled_tokens_per_dp_rank[1] == [8, 3] # req3=8, req4=3
|
|
258
|
+
|
|
259
|
+
# 5. num_req_per_dp_rank: Number of requests per DP rank
|
|
260
|
+
assert isinstance(num_req_per_dp_rank, dict)
|
|
261
|
+
assert num_req_per_dp_rank[0] == 2
|
|
262
|
+
assert num_req_per_dp_rank[1] == 2
|
|
263
|
+
|
|
264
|
+
# 6. padded_num_scheduled_tokens_per_dp_rank: Padded token count per rank
|
|
265
|
+
assert padded_num_scheduled_tokens_per_dp_rank == 16
|
|
266
|
+
|
|
267
|
+
# 7. padded_num_reqs: Total padded requests across all ranks
|
|
268
|
+
assert padded_num_reqs == 32 # 2 DP ranks * 16 padded reqs per rank
|
|
269
|
+
|
|
270
|
+
# 8. padded_total_num_scheduled_tokens: Total padded tokens across all ranks
|
|
271
|
+
assert padded_total_num_scheduled_tokens == 32 # 2 DP ranks * 16 padded tokens per rank
|
|
272
|
+
|
|
273
|
+
# 9. padded_num_reqs_per_dp_rank: Padded requests per DP rank
|
|
274
|
+
assert padded_num_reqs_per_dp_rank == 16
|
|
275
|
+
|
|
276
|
+
# 10. logits_indices_selector: Array to map back to original request order
|
|
277
|
+
assert isinstance(logits_indices_selector, np.ndarray)
|
|
278
|
+
assert len(logits_indices_selector) == 4 # One for each request
|
|
279
|
+
# Should map distributed positions back to original order
|
|
280
|
+
expected_selector = np.array([0, 1, 16, 17])
|
|
281
|
+
np.testing.assert_array_equal(logits_indices_selector,
|
|
282
|
+
expected_selector)
|
|
283
|
+
|
|
284
|
+
# 11. max_num_reqs_per_dp_rank: Maximum requests per DP rank
|
|
285
|
+
assert max_num_reqs_per_dp_rank == 4 # max_num_reqs (8) // dp_size (2)
|
|
286
|
+
|
|
287
|
+
def test_prepare_dp_input_metadata_empty_rank(self):
|
|
288
|
+
"""Test metadata preparation with one empty DP rank"""
|
|
289
|
+
# Create test data where all requests go to rank 0, leaving rank 1 empty
|
|
290
|
+
num_scheduled_tokens = {"req1": 10, "req2": 5}
|
|
291
|
+
assigned_dp_ranks = {"req1": 0, "req2": 0}
|
|
292
|
+
|
|
293
|
+
self.runner.input_batch.num_reqs = 2
|
|
294
|
+
self.runner.input_batch.req_ids = ["req1", "req2"]
|
|
295
|
+
self.runner.max_num_reqs = 8
|
|
296
|
+
|
|
297
|
+
scheduler_output = self._create_mock_scheduler_output(
|
|
298
|
+
num_scheduled_tokens, assigned_dp_ranks)
|
|
299
|
+
|
|
300
|
+
with patch('tpu_inference.runner.tpu_runner.runner_utils'
|
|
301
|
+
) as mock_runner_utils:
|
|
302
|
+
mock_runner_utils.get_padded_token_len.side_effect = lambda paddings_list, val: 16 if val <= 15 else 32
|
|
303
|
+
|
|
304
|
+
result = self.runner._prepare_dp_input_metadata(scheduler_output)
|
|
305
|
+
|
|
306
|
+
(req_ids_dp, req_indices_dp, num_scheduled_tokens_per_dp_rank,
|
|
307
|
+
scheduled_tokens_per_dp_rank, num_req_per_dp_rank,
|
|
308
|
+
padded_num_scheduled_tokens_per_dp_rank, padded_num_reqs,
|
|
309
|
+
padded_total_num_scheduled_tokens, padded_num_reqs_per_dp_rank,
|
|
310
|
+
logits_indices_selector, max_num_reqs_per_dp_rank) = result
|
|
311
|
+
|
|
312
|
+
# 1. req_ids_dp
|
|
313
|
+
assert isinstance(req_ids_dp, dict)
|
|
314
|
+
assert req_ids_dp[0] == ["req1", "req2"]
|
|
315
|
+
assert req_ids_dp[1] == [] # Empty rank
|
|
316
|
+
|
|
317
|
+
# 2. req_indices_dp
|
|
318
|
+
assert isinstance(req_indices_dp, dict)
|
|
319
|
+
assert req_indices_dp[0] == [0, 1] # req1, req2 indices
|
|
320
|
+
assert req_indices_dp[1] == [] # Empty rank
|
|
321
|
+
|
|
322
|
+
# 3. num_scheduled_tokens_per_dp_rank
|
|
323
|
+
assert isinstance(num_scheduled_tokens_per_dp_rank, dict)
|
|
324
|
+
assert num_scheduled_tokens_per_dp_rank[0] == 15 # 10 + 5
|
|
325
|
+
assert num_scheduled_tokens_per_dp_rank[1] == 0 # Empty rank
|
|
326
|
+
|
|
327
|
+
# 4. scheduled_tokens_per_dp_rank
|
|
328
|
+
assert isinstance(scheduled_tokens_per_dp_rank, dict)
|
|
329
|
+
assert scheduled_tokens_per_dp_rank[0] == [10,
|
|
330
|
+
5] # req1=10, req2=5
|
|
331
|
+
assert scheduled_tokens_per_dp_rank[1] == [] # Empty rank
|
|
332
|
+
|
|
333
|
+
# 5. num_req_per_dp_rank
|
|
334
|
+
assert isinstance(num_req_per_dp_rank, dict)
|
|
335
|
+
assert num_req_per_dp_rank[0] == 2 # Both requests on rank 0
|
|
336
|
+
assert num_req_per_dp_rank[1] == 0 # No requests on rank 1
|
|
337
|
+
|
|
338
|
+
# 6. padded_num_scheduled_tokens_per_dp_rank
|
|
339
|
+
assert padded_num_scheduled_tokens_per_dp_rank == 16
|
|
340
|
+
|
|
341
|
+
# 7. padded_num_reqs
|
|
342
|
+
assert padded_num_reqs == 32 # 2 DP ranks * 16 padded reqs per rank
|
|
343
|
+
|
|
344
|
+
# 8. padded_total_num_scheduled_tokens
|
|
345
|
+
assert padded_total_num_scheduled_tokens == 32 # 2 DP ranks * 16 padded tokens per rank
|
|
346
|
+
|
|
347
|
+
# 10. padded_num_reqs_per_dp_rank: Padded requests per DP rank
|
|
348
|
+
assert padded_num_reqs_per_dp_rank == 16
|
|
349
|
+
|
|
350
|
+
# 11. logits_indices_selector: Should preserve original order since no reordering needed
|
|
351
|
+
assert isinstance(logits_indices_selector, np.ndarray)
|
|
352
|
+
assert len(logits_indices_selector) == 2
|
|
353
|
+
# Both requests on DP rank 0, positions 0 and 1
|
|
354
|
+
expected_selector = np.array([0, 1])
|
|
355
|
+
np.testing.assert_array_equal(logits_indices_selector,
|
|
356
|
+
expected_selector)
|
|
357
|
+
|
|
358
|
+
# 12. max_num_reqs_per_dp_rank: Maximum requests per DP rank
|
|
359
|
+
assert max_num_reqs_per_dp_rank == 4 # max_num_reqs (8) // dp_size (2)
|
|
360
|
+
|
|
361
|
+
def test_prepare_dp_input_metadata_logits_indices_selector_ordering(self):
|
|
362
|
+
"""Test logits_indices_selector with mixed DP rank assignment."""
|
|
363
|
+
# Create requests with mixed assignment to test reordering
|
|
364
|
+
num_scheduled_tokens = {"req1": 4, "req2": 6, "req3": 2}
|
|
365
|
+
assigned_dp_ranks = {
|
|
366
|
+
"req1": 1,
|
|
367
|
+
"req2": 0,
|
|
368
|
+
"req3": 1
|
|
369
|
+
} # req2 on rank 0, req1&req3 on rank 1
|
|
370
|
+
|
|
371
|
+
self.runner.input_batch.num_reqs = 3
|
|
372
|
+
self.runner.input_batch.req_ids = ["req1", "req2", "req3"]
|
|
373
|
+
|
|
374
|
+
scheduler_output = self._create_mock_scheduler_output(
|
|
375
|
+
num_scheduled_tokens, assigned_dp_ranks)
|
|
376
|
+
|
|
377
|
+
with patch('tpu_inference.runner.tpu_runner.runner_utils'
|
|
378
|
+
) as mock_runner_utils:
|
|
379
|
+
mock_runner_utils.get_padded_token_len.side_effect = lambda paddings_list, val: 8 if val <= 6 else 16
|
|
380
|
+
|
|
381
|
+
result = self.runner._prepare_dp_input_metadata(scheduler_output)
|
|
382
|
+
|
|
383
|
+
(req_ids_dp, req_indices_dp, _, _, _, _, _, _, _,
|
|
384
|
+
logits_indices_selector, _) = result
|
|
385
|
+
|
|
386
|
+
# Verify request distribution
|
|
387
|
+
assert req_ids_dp[0] == ["req2"] # rank 0: req2 (index 1)
|
|
388
|
+
assert req_ids_dp[1] == [
|
|
389
|
+
"req1", "req3"
|
|
390
|
+
] # rank 1: req1 (index 0), req3 (index 2)
|
|
391
|
+
|
|
392
|
+
assert req_indices_dp[0] == [1] # req2 has original index 1
|
|
393
|
+
assert req_indices_dp[1] == [
|
|
394
|
+
0, 2
|
|
395
|
+
] # req1 has index 0, req3 has index 2
|
|
396
|
+
|
|
397
|
+
# The logits_indices_selector should map the DP-distributed positions back to original order
|
|
398
|
+
|
|
399
|
+
assert isinstance(logits_indices_selector, np.ndarray)
|
|
400
|
+
assert len(logits_indices_selector) == 3
|
|
401
|
+
|
|
402
|
+
expected_positions = np.array([8, 0, 9])
|
|
403
|
+
np.testing.assert_array_equal(logits_indices_selector,
|
|
404
|
+
expected_positions)
|
|
405
|
+
|
|
406
|
+
@patch('tpu_inference.runner.tpu_runner.NamedSharding')
|
|
407
|
+
@patch('tpu_inference.runner.tpu_runner.runner_utils')
|
|
408
|
+
@patch('tpu_inference.runner.tpu_runner.device_array',
|
|
409
|
+
side_effect=lambda mesh, tensors, **kwargs: tensors)
|
|
410
|
+
@patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
|
|
411
|
+
def test_prepare_inputs_dp_verify_content_balanced(self,
|
|
412
|
+
mock_sampling_metadata,
|
|
413
|
+
mock_device_array,
|
|
414
|
+
mock_runner_utils,
|
|
415
|
+
mock_named_sharding):
|
|
416
|
+
"""Test _prepare_inputs_dp with content verification for balanced distribution."""
|
|
417
|
+
|
|
418
|
+
# Setup mocking with specific behavior for tokens vs requests
|
|
419
|
+
def mock_get_padded_token_len(paddings_list, val):
|
|
420
|
+
# For tokens: 8 if val <= 3 else 16
|
|
421
|
+
# For requests: 4 if val <= 1 else 8
|
|
422
|
+
if val <= 1:
|
|
423
|
+
return 4 # For request padding
|
|
424
|
+
elif val <= 3:
|
|
425
|
+
return 8 # For token padding
|
|
426
|
+
else:
|
|
427
|
+
return 16
|
|
428
|
+
|
|
429
|
+
mock_runner_utils.get_padded_token_len.side_effect = mock_get_padded_token_len
|
|
430
|
+
mock_sampling_instance = MagicMock()
|
|
431
|
+
mock_sampling_metadata.from_input_batch.return_value = mock_sampling_instance
|
|
432
|
+
mock_named_sharding.return_value = MagicMock()
|
|
433
|
+
|
|
434
|
+
# Setup deterministic test data
|
|
435
|
+
num_scheduled_tokens = {"req1": 2, "req2": 3}
|
|
436
|
+
assigned_dp_ranks = {"req1": 0, "req2": 1}
|
|
437
|
+
|
|
438
|
+
self.runner.input_batch.num_reqs = 2
|
|
439
|
+
self.runner.input_batch.req_ids = ["req1", "req2"]
|
|
440
|
+
self.runner.input_batch.num_computed_tokens_cpu = np.array(
|
|
441
|
+
[5, 6]) # Starting positions
|
|
442
|
+
|
|
443
|
+
# Setup known token sequences for verification
|
|
444
|
+
self.runner.input_batch.token_ids_cpu = np.zeros((8, 64),
|
|
445
|
+
dtype=np.int32)
|
|
446
|
+
# req1: [1001, 1002, 1003, ...]
|
|
447
|
+
# req2: [2001, 2002, 2003, ...]
|
|
448
|
+
for i in range(2):
|
|
449
|
+
start_val = (i + 1) * 1000 + 1
|
|
450
|
+
for j in range(64):
|
|
451
|
+
self.runner.input_batch.token_ids_cpu[i, j] = start_val + j
|
|
452
|
+
|
|
453
|
+
scheduler_output = self._create_mock_scheduler_output(
|
|
454
|
+
num_scheduled_tokens, assigned_dp_ranks)
|
|
455
|
+
|
|
456
|
+
# Setup additional required attributes
|
|
457
|
+
self.runner.uses_mrope = False
|
|
458
|
+
self.runner.phase_based_profiler = None
|
|
459
|
+
self.runner.lora_config = None
|
|
460
|
+
self.runner.mesh = MagicMock()
|
|
461
|
+
self.runner.data_parallel_sharding = MagicMock()
|
|
462
|
+
self.runner.data_parallel_attn_sharding = MagicMock()
|
|
463
|
+
self.runner.mm_manager = MagicMock()
|
|
464
|
+
self.runner.speculative_decoding_manager = MagicMock()
|
|
465
|
+
self.runner.lora_utils = MagicMock()
|
|
466
|
+
# self.runner.mrope_positions_cpu = np.zeros((3, 64), dtype=np.int64)
|
|
467
|
+
|
|
468
|
+
# Execute the method
|
|
469
|
+
result = self.runner._prepare_inputs_dp(scheduler_output)
|
|
470
|
+
input_ids, positions, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
|
|
471
|
+
# 1. Verify input_ids content
|
|
472
|
+
expected_input_ids = np.zeros(16, dtype=np.int32)
|
|
473
|
+
expected_input_ids[:2] = [1006, 1007]
|
|
474
|
+
expected_input_ids[8:11] = [2007, 2008, 2009]
|
|
475
|
+
assert np.array_equal(input_ids, expected_input_ids)
|
|
476
|
+
|
|
477
|
+
# 2. Verify attention_metadata positions content
|
|
478
|
+
expected_positions = np.zeros(16, dtype=np.int32)
|
|
479
|
+
expected_positions[:2] = [5, 6] # req1 positions
|
|
480
|
+
expected_positions[8:11] = [6, 7, 8]
|
|
481
|
+
assert np.array_equal(attention_metadata.input_positions,
|
|
482
|
+
expected_positions)
|
|
483
|
+
|
|
484
|
+
# 3. Verify query_start_loc content
|
|
485
|
+
query_start_loc = attention_metadata.query_start_loc_cpu
|
|
486
|
+
max_num_reqs_per_dp = self.runner.max_num_reqs // 2
|
|
487
|
+
expected_query_start = np.zeros(self.runner.max_num_reqs + 2,
|
|
488
|
+
dtype=np.int32)
|
|
489
|
+
# DP rank 0: cumsum([2]) = [2] at positions [1:2] → [0, 2, 1, 1, 1]
|
|
490
|
+
expected_query_start[1] = 2 # req1 has 2 tokens
|
|
491
|
+
expected_query_start[2:max_num_reqs_per_dp + 1] = 1
|
|
492
|
+
# DP rank 1: cumsum([3]) = [3] at positions [6:7] → [0, 3, 1, 1, 1]
|
|
493
|
+
expected_query_start[max_num_reqs_per_dp + 2] = 3 # req2 has 3 tokens
|
|
494
|
+
expected_query_start[max_num_reqs_per_dp + 3:] = 1
|
|
495
|
+
assert np.array_equal(query_start_loc, expected_query_start)
|
|
496
|
+
|
|
497
|
+
# 4. Verify seq_lens content
|
|
498
|
+
seq_lens = attention_metadata.seq_lens_cpu
|
|
499
|
+
# Should be computed_tokens + scheduled_tokens for each request
|
|
500
|
+
# DP rank 0: req1 at position 0, DP rank 1: req2 at position 4
|
|
501
|
+
expected_seq_lens = np.array([7, 0, 0, 0, 9, 0, 0,
|
|
502
|
+
0]) # req1: 5+2=7, req2: 6+3=9
|
|
503
|
+
assert np.array_equal(seq_lens, expected_seq_lens)
|
|
504
|
+
|
|
505
|
+
# 5. Verify request_distribution content
|
|
506
|
+
expected_distribution = np.array([[0, 0, 1], [0, 0, 1]]).flatten()
|
|
507
|
+
np.testing.assert_array_equal(attention_metadata.request_distribution,
|
|
508
|
+
expected_distribution)
|
|
509
|
+
|
|
510
|
+
# 6. Verify logits_indices content
|
|
511
|
+
assert len(logits_indices) == 8 # padded_num_reqs
|
|
512
|
+
expected_logits = np.full(8, -1, dtype=np.int32)
|
|
513
|
+
expected_logits[0] = 1 # req1 last token position (2-1)
|
|
514
|
+
expected_logits[
|
|
515
|
+
4] = 2 # req2 last token position (3-1) at DP rank 1 offset (4*1)
|
|
516
|
+
assert np.array_equal(logits_indices, expected_logits)
|
|
517
|
+
|
|
518
|
+
# 7. Verify logits_indices_selector
|
|
519
|
+
assert len(logits_indices_selector) == 2
|
|
520
|
+
assert np.array_equal(logits_indices_selector, np.array([0, 4]))
|
|
521
|
+
|
|
522
|
+
@patch('tpu_inference.runner.tpu_runner.NamedSharding')
|
|
523
|
+
@patch('tpu_inference.runner.tpu_runner.runner_utils')
|
|
524
|
+
@patch('tpu_inference.runner.tpu_runner.device_array',
|
|
525
|
+
side_effect=lambda mesh, tensors, **kwargs: tensors)
|
|
526
|
+
@patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
|
|
527
|
+
def test_prepare_inputs_dp_verify_content_empty_rank(
|
|
528
|
+
self, mock_sampling_metadata, mock_device_array, mock_runner_utils,
|
|
529
|
+
mock_named_sharding):
|
|
530
|
+
"""Test _prepare_inputs_dp with detailed content verification for empty rank case."""
|
|
531
|
+
|
|
532
|
+
# Setup mocking
|
|
533
|
+
def mock_get_padded_token_len(paddings_list, val):
|
|
534
|
+
if val <= 2:
|
|
535
|
+
return 4 # For request padding (max 2 requests)
|
|
536
|
+
elif val <= 5:
|
|
537
|
+
return 8 # For token padding
|
|
538
|
+
else:
|
|
539
|
+
return 16
|
|
540
|
+
|
|
541
|
+
mock_runner_utils.get_padded_token_len.side_effect = mock_get_padded_token_len
|
|
542
|
+
mock_sampling_instance = MagicMock()
|
|
543
|
+
mock_sampling_metadata.from_input_batch.return_value = mock_sampling_instance
|
|
544
|
+
mock_named_sharding.return_value = MagicMock()
|
|
545
|
+
|
|
546
|
+
# Setup test data with all requests on rank 0 (empty rank 1)
|
|
547
|
+
num_scheduled_tokens = {"req1": 3, "req2": 2}
|
|
548
|
+
assigned_dp_ranks = {
|
|
549
|
+
"req1": 0,
|
|
550
|
+
"req2": 0
|
|
551
|
+
} # Both on rank 0, rank 1 empty
|
|
552
|
+
|
|
553
|
+
self.runner.input_batch.num_reqs = 2
|
|
554
|
+
self.runner.input_batch.req_ids = ["req1", "req2"]
|
|
555
|
+
self.runner.input_batch.num_computed_tokens_cpu = np.array(
|
|
556
|
+
[4, 6]) # Starting positions
|
|
557
|
+
|
|
558
|
+
# Setup deterministic token sequences for verification
|
|
559
|
+
self.runner.input_batch.token_ids_cpu = np.zeros((8, 64),
|
|
560
|
+
dtype=np.int32)
|
|
561
|
+
# req1: [5001, 5002, 5003, ...] starting at position 4
|
|
562
|
+
# req2: [6001, 6002, 6003, ...] starting at position 6
|
|
563
|
+
for i in range(2):
|
|
564
|
+
start_val = (i + 5) * 1000 + 1 # 5001, 6001
|
|
565
|
+
for j in range(64):
|
|
566
|
+
self.runner.input_batch.token_ids_cpu[i, j] = start_val + j
|
|
567
|
+
|
|
568
|
+
scheduler_output = self._create_mock_scheduler_output(
|
|
569
|
+
num_scheduled_tokens, assigned_dp_ranks)
|
|
570
|
+
|
|
571
|
+
# Setup required attributes
|
|
572
|
+
self.runner.uses_mrope = False
|
|
573
|
+
self.runner.phase_based_profiler = None
|
|
574
|
+
self.runner.lora_config = None
|
|
575
|
+
self.runner.mesh = MagicMock()
|
|
576
|
+
self.runner.data_parallel_sharding = MagicMock()
|
|
577
|
+
self.runner.data_parallel_attn_sharding = MagicMock()
|
|
578
|
+
self.runner.mm_manager = MagicMock()
|
|
579
|
+
self.runner.speculative_decoding_manager = MagicMock()
|
|
580
|
+
self.runner.lora_utils = MagicMock()
|
|
581
|
+
|
|
582
|
+
# Execute the method
|
|
583
|
+
result = self.runner._prepare_inputs_dp(scheduler_output)
|
|
584
|
+
input_ids, positions, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
|
|
585
|
+
|
|
586
|
+
# 1. Verify input_ids
|
|
587
|
+
expected_input_ids = np.zeros(16, dtype=np.int32)
|
|
588
|
+
# Rank 0
|
|
589
|
+
expected_input_ids[:5] = [5005, 5006, 5007, 6007, 6008]
|
|
590
|
+
# Rank 1 (positions 8-15) should remain zeros
|
|
591
|
+
assert np.array_equal(input_ids, expected_input_ids)
|
|
592
|
+
|
|
593
|
+
# 2. Verify attention_metadata
|
|
594
|
+
expected_positions = np.zeros(16, dtype=np.int32)
|
|
595
|
+
expected_positions[:3] = [4, 5, 6] # req1 positions: 4 + [0, 1, 2]
|
|
596
|
+
expected_positions[3:5] = [6, 7] # req2 positions: 6 + [0, 1]
|
|
597
|
+
# Rank 1 positions (8-15) remain zeros
|
|
598
|
+
assert np.array_equal(attention_metadata.input_positions,
|
|
599
|
+
expected_positions)
|
|
600
|
+
|
|
601
|
+
# 3. Verify query_start_loc
|
|
602
|
+
query_start_loc = attention_metadata.query_start_loc_cpu
|
|
603
|
+
max_num_reqs_per_dp = self.runner.max_num_reqs // 2 # 4
|
|
604
|
+
expected_query_start = np.zeros(self.runner.max_num_reqs + 2,
|
|
605
|
+
dtype=np.int32)
|
|
606
|
+
# Rank 0: req1 (3 tokens), req2 (2 tokens)
|
|
607
|
+
expected_query_start[1] = 3 # req1 has 3 tokens
|
|
608
|
+
expected_query_start[2] = 5 # cumulative: 3 + 2 = 5
|
|
609
|
+
expected_query_start[3:max_num_reqs_per_dp + 1] = 1 # padding
|
|
610
|
+
# Rank 1: empty (all zeros)
|
|
611
|
+
expected_query_start[max_num_reqs_per_dp +
|
|
612
|
+
1:] = 0 # Empty rank sets to 0
|
|
613
|
+
assert np.array_equal(query_start_loc, expected_query_start)
|
|
614
|
+
|
|
615
|
+
# 4. Verify seq_lens
|
|
616
|
+
seq_lens = attention_metadata.seq_lens_cpu
|
|
617
|
+
expected_seq_lens = np.zeros(8, dtype=np.int32)
|
|
618
|
+
# Rank 0: req1 (4+3=7), req2 (6+2=8), then padding
|
|
619
|
+
expected_seq_lens[
|
|
620
|
+
0] = 7 # req1: computed_tokens(4) + scheduled_tokens(3)
|
|
621
|
+
expected_seq_lens[
|
|
622
|
+
1] = 8 # req2: computed_tokens(6) + scheduled_tokens(2)
|
|
623
|
+
# Rank 1: all zeros
|
|
624
|
+
assert np.array_equal(seq_lens, expected_seq_lens)
|
|
625
|
+
|
|
626
|
+
# 5. Verify request_distribution
|
|
627
|
+
expected_distribution = np.array([[0, 0, 2], [0, 0, 0]]).flatten()
|
|
628
|
+
np.testing.assert_array_equal(attention_metadata.request_distribution,
|
|
629
|
+
expected_distribution)
|
|
630
|
+
|
|
631
|
+
# 6. Verify logits_indices
|
|
632
|
+
assert len(
|
|
633
|
+
logits_indices) == 8 # padded_num_reqs (8 in this case, not 16)
|
|
634
|
+
# Rank 0: req1 ends at pos 2, req2 ends at pos 4
|
|
635
|
+
# Rank 1: empty, so -1 padding
|
|
636
|
+
expected_logits = np.full(8, -1, dtype=np.int32)
|
|
637
|
+
expected_logits[0] = 2 # req1 ends at position 2 (3-1)
|
|
638
|
+
expected_logits[1] = 4 # req2 ends at position 4 (5-1)
|
|
639
|
+
assert np.array_equal(logits_indices, expected_logits)
|
|
640
|
+
|
|
641
|
+
# 7. Verify logits_indices_selector
|
|
642
|
+
assert len(logits_indices_selector) == 2
|
|
643
|
+
expected_selector = np.array([0, 1])
|
|
644
|
+
np.testing.assert_array_equal(logits_indices_selector,
|
|
645
|
+
expected_selector)
|
|
646
|
+
|
|
647
|
+
@patch('tpu_inference.runner.tpu_runner.NamedSharding')
|
|
648
|
+
@patch('tpu_inference.runner.tpu_runner.runner_utils')
|
|
649
|
+
@patch('tpu_inference.runner.tpu_runner.device_array',
|
|
650
|
+
side_effect=lambda mesh, tensors, **kwargs: tensors)
|
|
651
|
+
@patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
|
|
652
|
+
def test_prepare_inputs_dp_with_decode_requests(self,
|
|
653
|
+
mock_sampling_metadata,
|
|
654
|
+
mock_device_array,
|
|
655
|
+
mock_runner_utils,
|
|
656
|
+
mock_named_sharding):
|
|
657
|
+
"""Test _prepare_inputs_dp with decode requests (1 token each) to verify request_distribution."""
|
|
658
|
+
|
|
659
|
+
# Setup mocking
|
|
660
|
+
def mock_get_padded_token_len(paddings_list, val):
|
|
661
|
+
if val <= 2:
|
|
662
|
+
return 4 # For request padding
|
|
663
|
+
elif val <= 4:
|
|
664
|
+
return 8 # For token padding
|
|
665
|
+
else:
|
|
666
|
+
return 16
|
|
667
|
+
|
|
668
|
+
mock_runner_utils.get_padded_token_len.side_effect = mock_get_padded_token_len
|
|
669
|
+
mock_sampling_instance = MagicMock()
|
|
670
|
+
mock_sampling_metadata.from_input_batch.return_value = mock_sampling_instance
|
|
671
|
+
mock_named_sharding.return_value = MagicMock()
|
|
672
|
+
|
|
673
|
+
# Setup test data with decode requests (1 token) and prefill requests (>1 token)
|
|
674
|
+
# req1: decode (1 token), req2: decode (1 token), req3: prefill (3 tokens), req4: decode (1 token)
|
|
675
|
+
num_scheduled_tokens = {"req1": 1, "req2": 1, "req3": 3, "req4": 1}
|
|
676
|
+
assigned_dp_ranks = {"req1": 0, "req2": 0, "req3": 1, "req4": 1}
|
|
677
|
+
|
|
678
|
+
self.runner.input_batch.num_reqs = 4
|
|
679
|
+
self.runner.input_batch.req_ids = ["req1", "req2", "req3", "req4"]
|
|
680
|
+
self.runner.input_batch.num_computed_tokens_cpu = np.array(
|
|
681
|
+
[5, 6, 7, 8])
|
|
682
|
+
self.runner.input_batch.token_ids_cpu = np.zeros((8, 64),
|
|
683
|
+
dtype=np.int32)
|
|
684
|
+
|
|
685
|
+
scheduler_output = self._create_mock_scheduler_output(
|
|
686
|
+
num_scheduled_tokens, assigned_dp_ranks)
|
|
687
|
+
|
|
688
|
+
# Setup required attributes
|
|
689
|
+
self.runner.uses_mrope = False
|
|
690
|
+
self.runner.phase_based_profiler = None
|
|
691
|
+
self.runner.lora_config = None
|
|
692
|
+
self.runner.mesh = MagicMock()
|
|
693
|
+
self.runner.data_parallel_sharding = MagicMock()
|
|
694
|
+
self.runner.data_parallel_attn_sharding = MagicMock()
|
|
695
|
+
self.runner.mm_manager = MagicMock()
|
|
696
|
+
self.runner.speculative_decoding_manager = MagicMock()
|
|
697
|
+
self.runner.lora_utils = MagicMock()
|
|
698
|
+
|
|
699
|
+
# Execute the method
|
|
700
|
+
result = self.runner._prepare_inputs_dp(scheduler_output)
|
|
701
|
+
input_ids, positions, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
|
|
702
|
+
|
|
703
|
+
# Verify request_distribution
|
|
704
|
+
# DP rank 0: req1 (decode), req2 (decode) -> [2, 2, 2]
|
|
705
|
+
# DP rank 1: req3 (prefill), req4 (decode) -> [1, 1, 2]
|
|
706
|
+
expected_distribution = np.array([[2, 2, 2], [1, 1, 2]]).flatten()
|
|
707
|
+
np.testing.assert_array_equal(attention_metadata.request_distribution,
|
|
708
|
+
expected_distribution)
|
|
709
|
+
|
|
710
|
+
@patch('tpu_inference.runner.tpu_runner.NamedSharding')
|
|
711
|
+
@patch('tpu_inference.runner.tpu_runner.runner_utils')
|
|
712
|
+
@patch('tpu_inference.runner.tpu_runner.device_array',
|
|
713
|
+
side_effect=lambda mesh, tensors, **kwargs: tensors)
|
|
714
|
+
@patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
|
|
715
|
+
def test_prepare_inputs_dp_all_decode_requests(self,
|
|
716
|
+
mock_sampling_metadata,
|
|
717
|
+
mock_device_array,
|
|
718
|
+
mock_runner_utils,
|
|
719
|
+
mock_named_sharding):
|
|
720
|
+
"""Test _prepare_inputs_dp with all decode requests."""
|
|
721
|
+
|
|
722
|
+
# Setup mocking
|
|
723
|
+
def mock_get_padded_token_len(paddings_list, val):
|
|
724
|
+
if val <= 2:
|
|
725
|
+
return 4
|
|
726
|
+
elif val <= 4:
|
|
727
|
+
return 8
|
|
728
|
+
else:
|
|
729
|
+
return 16
|
|
730
|
+
|
|
731
|
+
mock_runner_utils.get_padded_token_len.side_effect = mock_get_padded_token_len
|
|
732
|
+
mock_sampling_instance = MagicMock()
|
|
733
|
+
mock_sampling_metadata.from_input_batch.return_value = mock_sampling_instance
|
|
734
|
+
mock_named_sharding.return_value = MagicMock()
|
|
735
|
+
|
|
736
|
+
# All requests are decode (1 token each)
|
|
737
|
+
num_scheduled_tokens = {"req1": 1, "req2": 1}
|
|
738
|
+
assigned_dp_ranks = {"req1": 0, "req2": 1}
|
|
739
|
+
|
|
740
|
+
self.runner.input_batch.num_reqs = 2
|
|
741
|
+
self.runner.input_batch.req_ids = ["req1", "req2"]
|
|
742
|
+
self.runner.input_batch.num_computed_tokens_cpu = np.array([5, 6])
|
|
743
|
+
self.runner.input_batch.token_ids_cpu = np.zeros((8, 64),
|
|
744
|
+
dtype=np.int32)
|
|
745
|
+
|
|
746
|
+
scheduler_output = self._create_mock_scheduler_output(
|
|
747
|
+
num_scheduled_tokens, assigned_dp_ranks)
|
|
748
|
+
|
|
749
|
+
# Setup required attributes
|
|
750
|
+
self.runner.uses_mrope = False
|
|
751
|
+
self.runner.phase_based_profiler = None
|
|
752
|
+
self.runner.lora_config = None
|
|
753
|
+
self.runner.mesh = MagicMock()
|
|
754
|
+
self.runner.data_parallel_sharding = MagicMock()
|
|
755
|
+
self.runner.data_parallel_attn_sharding = MagicMock()
|
|
756
|
+
self.runner.mm_manager = MagicMock()
|
|
757
|
+
self.runner.speculative_decoding_manager = MagicMock()
|
|
758
|
+
self.runner.lora_utils = MagicMock()
|
|
759
|
+
|
|
760
|
+
# Execute the method
|
|
761
|
+
result = self.runner._prepare_inputs_dp(scheduler_output)
|
|
762
|
+
input_ids, positions, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
|
|
763
|
+
|
|
764
|
+
# Verify request_distribution
|
|
765
|
+
# Both ranks have only decode requests
|
|
766
|
+
# DP rank 0: req1 (decode) -> [1, 1, 1]
|
|
767
|
+
# DP rank 1: req2 (decode) -> [1, 1, 1]
|
|
768
|
+
expected_distribution = np.array([[1, 1, 1], [1, 1, 1]]).flatten()
|
|
769
|
+
np.testing.assert_array_equal(attention_metadata.request_distribution,
|
|
770
|
+
expected_distribution)
|
|
771
|
+
|
|
772
|
+
@patch('tpu_inference.runner.tpu_runner.NamedSharding')
|
|
773
|
+
@patch('tpu_inference.runner.tpu_runner.runner_utils')
|
|
774
|
+
@patch('tpu_inference.runner.tpu_runner.device_array',
|
|
775
|
+
side_effect=lambda mesh, tensors, **kwargs: tensors)
|
|
776
|
+
@patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
|
|
777
|
+
def test_prepare_async_token_substitution_indices_dp(
|
|
778
|
+
self, mock_sampling_metadata, mock_device_array, mock_runner_utils,
|
|
779
|
+
mock_named_sharding):
|
|
780
|
+
|
|
781
|
+
# Setup test data
|
|
782
|
+
req_ids_dp = {0: ["req1", "req2"], 1: ["req3"]}
|
|
783
|
+
scheduled_tokens_per_dp_rank = {0: [3, 2], 1: [4]}
|
|
784
|
+
padded_num_scheduled_tokens_per_dp_rank = 8
|
|
785
|
+
dp_size = 2
|
|
786
|
+
|
|
787
|
+
# Setup _pre_async_results with placeholder mapping
|
|
788
|
+
self.runner._pre_async_results = MagicMock()
|
|
789
|
+
self.runner._pre_async_results.placeholder_req_id_to_index = {
|
|
790
|
+
"req1": 0,
|
|
791
|
+
"req3": 2
|
|
792
|
+
} # req2 is not a placeholder
|
|
793
|
+
|
|
794
|
+
# Call the method
|
|
795
|
+
result = self.runner._prepare_async_token_substitution_indices_dp(
|
|
796
|
+
req_ids_dp, scheduled_tokens_per_dp_rank,
|
|
797
|
+
padded_num_scheduled_tokens_per_dp_rank, dp_size)
|
|
798
|
+
|
|
799
|
+
token_in_tpu_cur_input_indices_dp, token_in_tpu_pre_next_tokens_indices_dp = result
|
|
800
|
+
|
|
801
|
+
# Verify DP rank 0
|
|
802
|
+
# req1: token_offset=0, acc_cur_len starts at 0, after 3 tokens: 3, so last token at 2
|
|
803
|
+
# req2: not a placeholder, should be skipped
|
|
804
|
+
assert token_in_tpu_cur_input_indices_dp[0] == [2]
|
|
805
|
+
assert token_in_tpu_pre_next_tokens_indices_dp[0] == [0]
|
|
806
|
+
|
|
807
|
+
# Verify DP rank 1
|
|
808
|
+
# req3: token_offset=8, acc_cur_len starts at 8, after 4 tokens: 12, so last token at 11
|
|
809
|
+
assert token_in_tpu_cur_input_indices_dp[1] == [11]
|
|
810
|
+
assert token_in_tpu_pre_next_tokens_indices_dp[1] == [2]
|
|
811
|
+
|
|
812
|
+
@patch('tpu_inference.runner.tpu_runner.NamedSharding')
|
|
813
|
+
@patch('tpu_inference.runner.tpu_runner.runner_utils')
|
|
814
|
+
@patch('tpu_inference.runner.tpu_runner.device_array',
|
|
815
|
+
side_effect=lambda mesh, tensors, **kwargs: tensors)
|
|
816
|
+
@patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
|
|
817
|
+
def test_prepare_async_token_substitution_indices_dp_no_placeholders(
|
|
818
|
+
self, mock_sampling_metadata, mock_device_array, mock_runner_utils,
|
|
819
|
+
mock_named_sharding):
|
|
820
|
+
"""Test when no requests are placeholders."""
|
|
821
|
+
|
|
822
|
+
req_ids_dp = {0: ["req1", "req2"], 1: ["req3"]}
|
|
823
|
+
scheduled_tokens_per_dp_rank = {0: [3, 2], 1: [4]}
|
|
824
|
+
padded_num_scheduled_tokens_per_dp_rank = 8
|
|
825
|
+
dp_size = 2
|
|
826
|
+
|
|
827
|
+
# No placeholders
|
|
828
|
+
self.runner._pre_async_results = MagicMock()
|
|
829
|
+
self.runner._pre_async_results.placeholder_req_id_to_index = {}
|
|
830
|
+
|
|
831
|
+
result = self.runner._prepare_async_token_substitution_indices_dp(
|
|
832
|
+
req_ids_dp, scheduled_tokens_per_dp_rank,
|
|
833
|
+
padded_num_scheduled_tokens_per_dp_rank, dp_size)
|
|
834
|
+
|
|
835
|
+
token_in_tpu_cur_input_indices_dp, token_in_tpu_pre_next_tokens_indices_dp = result
|
|
836
|
+
|
|
837
|
+
# All lists should be empty since no placeholders
|
|
838
|
+
assert token_in_tpu_cur_input_indices_dp[0] == []
|
|
839
|
+
assert token_in_tpu_pre_next_tokens_indices_dp[0] == []
|
|
840
|
+
assert token_in_tpu_cur_input_indices_dp[1] == []
|
|
841
|
+
assert token_in_tpu_pre_next_tokens_indices_dp[1] == []
|
|
842
|
+
|
|
843
|
+
def test_apply_async_token_substitution_empty_indices(self):
|
|
844
|
+
"""Test _apply_async_token_substitution with empty indices (line 1025)."""
|
|
845
|
+
|
|
846
|
+
# Bind the actual method
|
|
847
|
+
self.runner._apply_async_token_substitution = TPUModelRunner._apply_async_token_substitution.__get__(
|
|
848
|
+
self.runner)
|
|
849
|
+
|
|
850
|
+
input_ids = np.array([1, 2, 3, 4, 5])
|
|
851
|
+
token_in_tpu_cur_input_indices = np.array([])
|
|
852
|
+
token_in_tpu_pre_next_tokens_indices = np.array([])
|
|
853
|
+
|
|
854
|
+
# Setup _pre_async_results
|
|
855
|
+
self.runner._pre_async_results = MagicMock()
|
|
856
|
+
self.runner._pre_async_results.next_tokens = np.array([10, 20, 30])
|
|
857
|
+
self.runner.mesh = MagicMock()
|
|
858
|
+
|
|
859
|
+
result = self.runner._apply_async_token_substitution(
|
|
860
|
+
input_ids, token_in_tpu_cur_input_indices,
|
|
861
|
+
token_in_tpu_pre_next_tokens_indices)
|
|
862
|
+
|
|
863
|
+
# Should return input_ids unchanged
|
|
864
|
+
np.testing.assert_array_equal(result, input_ids)
|
|
865
|
+
|
|
866
|
+
@patch('tpu_inference.runner.tpu_runner.device_array',
|
|
867
|
+
side_effect=lambda mesh, tensors, **kwargs: tensors)
|
|
868
|
+
def test_apply_async_token_substitution_with_padding(
|
|
869
|
+
self, mock_device_array):
|
|
870
|
+
"""Test _apply_async_token_substitution with padding."""
|
|
871
|
+
|
|
872
|
+
# Bind the actual method
|
|
873
|
+
self.runner._apply_async_token_substitution = TPUModelRunner._apply_async_token_substitution.__get__(
|
|
874
|
+
self.runner)
|
|
875
|
+
|
|
876
|
+
input_ids = np.array([1, 2, 3, 4, 5, 6, 7, 8])
|
|
877
|
+
# Substitute positions 2 and 5
|
|
878
|
+
token_in_tpu_cur_input_indices = np.array([2, 5])
|
|
879
|
+
token_in_tpu_pre_next_tokens_indices = np.array([0, 1])
|
|
880
|
+
|
|
881
|
+
# Setup _pre_async_results
|
|
882
|
+
self.runner._pre_async_results = MagicMock()
|
|
883
|
+
self.runner._pre_async_results.next_tokens = np.array([100, 200, 300])
|
|
884
|
+
self.runner.mesh = MagicMock()
|
|
885
|
+
self.runner.maybe_forbid_compile = nullcontext()
|
|
886
|
+
|
|
887
|
+
# Mock the substitute function to verify it's called correctly
|
|
888
|
+
mock_substitute_fn = MagicMock(
|
|
889
|
+
return_value=np.array([1, 2, 100, 4, 5, 200, 7, 8]))
|
|
890
|
+
self.runner._substitute_placeholder_token_fn = mock_substitute_fn
|
|
891
|
+
|
|
892
|
+
_ = self.runner._apply_async_token_substitution(
|
|
893
|
+
input_ids, token_in_tpu_cur_input_indices,
|
|
894
|
+
token_in_tpu_pre_next_tokens_indices)
|
|
895
|
+
|
|
896
|
+
# Verify the substitute function was called
|
|
897
|
+
mock_substitute_fn.assert_called_once()
|
|
898
|
+
call_args = mock_substitute_fn.call_args[0]
|
|
899
|
+
|
|
900
|
+
# Verify input_ids
|
|
901
|
+
np.testing.assert_array_equal(call_args[0], input_ids)
|
|
902
|
+
|
|
903
|
+
# Verify padded indices length matches input_ids length
|
|
904
|
+
assert len(call_args[1]) == len(input_ids)
|
|
905
|
+
assert len(call_args[2]) == len(input_ids)
|
|
906
|
+
|
|
907
|
+
# Verify placeholder_num
|
|
908
|
+
assert call_args[4] == 2 # Number of actual substitutions
|
|
909
|
+
|
|
910
|
+
def test_prepare_inputs_routing_to_dp(self):
|
|
911
|
+
"""Test _prepare_inputs routes to _prepare_inputs_dp when dp_size > 1."""
|
|
912
|
+
|
|
913
|
+
# Bind the actual _prepare_inputs method
|
|
914
|
+
self.runner._prepare_inputs = TPUModelRunner._prepare_inputs.__get__(
|
|
915
|
+
self.runner)
|
|
916
|
+
|
|
917
|
+
self.runner.dp_size = 2
|
|
918
|
+
self.runner._prepare_inputs_dp = MagicMock(return_value=(None, None,
|
|
919
|
+
None, None,
|
|
920
|
+
None, None))
|
|
921
|
+
|
|
922
|
+
scheduler_output = MagicMock()
|
|
923
|
+
self.runner._prepare_inputs(scheduler_output)
|
|
924
|
+
|
|
925
|
+
# Verify _prepare_inputs_dp was called
|
|
926
|
+
self.runner._prepare_inputs_dp.assert_called_once_with(
|
|
927
|
+
scheduler_output)
|
|
928
|
+
|
|
929
|
+
def test_prepare_inputs_routing_to_non_dp(self):
|
|
930
|
+
"""Test _prepare_inputs routes to _prepare_inputs_non_dp when dp_size == 1."""
|
|
931
|
+
|
|
932
|
+
# Bind the actual _prepare_inputs method
|
|
933
|
+
self.runner._prepare_inputs = TPUModelRunner._prepare_inputs.__get__(
|
|
934
|
+
self.runner)
|
|
935
|
+
|
|
936
|
+
self.runner.dp_size = 1
|
|
937
|
+
self.runner._prepare_inputs_non_dp = MagicMock(
|
|
938
|
+
return_value=(None, None, None, None, None, None, None))
|
|
939
|
+
|
|
940
|
+
scheduler_output = MagicMock()
|
|
941
|
+
self.runner._prepare_inputs(scheduler_output)
|
|
942
|
+
|
|
943
|
+
# Verify _prepare_inputs_non_dp was called
|
|
944
|
+
self.runner._prepare_inputs_non_dp.assert_called_once_with(
|
|
945
|
+
scheduler_output)
|
|
946
|
+
|
|
947
|
+
@patch('tpu_inference.runner.tpu_runner.NamedSharding')
|
|
948
|
+
@patch('tpu_inference.runner.tpu_runner.runner_utils')
|
|
949
|
+
@patch('tpu_inference.runner.tpu_runner.device_array',
|
|
950
|
+
side_effect=lambda mesh, tensors, **kwargs: tensors)
|
|
951
|
+
@patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
|
|
952
|
+
def test_prepare_inputs_dp_with_async_scheduling(self,
|
|
953
|
+
mock_sampling_metadata,
|
|
954
|
+
mock_device_array,
|
|
955
|
+
mock_runner_utils,
|
|
956
|
+
mock_named_sharding):
|
|
957
|
+
|
|
958
|
+
# Setup mocking
|
|
959
|
+
def mock_get_padded_token_len(paddings_list, val):
|
|
960
|
+
if val <= 2:
|
|
961
|
+
return 4
|
|
962
|
+
elif val <= 5:
|
|
963
|
+
return 8
|
|
964
|
+
else:
|
|
965
|
+
return 16
|
|
966
|
+
|
|
967
|
+
mock_runner_utils.get_padded_token_len.side_effect = mock_get_padded_token_len
|
|
968
|
+
mock_sampling_instance = MagicMock()
|
|
969
|
+
mock_sampling_metadata.from_input_batch.return_value = mock_sampling_instance
|
|
970
|
+
mock_named_sharding.return_value = MagicMock()
|
|
971
|
+
|
|
972
|
+
# Setup test data
|
|
973
|
+
num_scheduled_tokens = {"req1": 3, "req2": 2}
|
|
974
|
+
assigned_dp_ranks = {"req1": 0, "req2": 1}
|
|
975
|
+
|
|
976
|
+
self.runner.input_batch.num_reqs = 2
|
|
977
|
+
self.runner.input_batch.req_ids = ["req1", "req2"]
|
|
978
|
+
self.runner.input_batch.num_computed_tokens_cpu = np.array([4, 6])
|
|
979
|
+
self.runner.input_batch.token_ids_cpu = np.zeros((8, 64),
|
|
980
|
+
dtype=np.int32)
|
|
981
|
+
|
|
982
|
+
scheduler_output = self._create_mock_scheduler_output(
|
|
983
|
+
num_scheduled_tokens, assigned_dp_ranks)
|
|
984
|
+
|
|
985
|
+
# Enable async scheduling
|
|
986
|
+
self.runner.scheduler_config.async_scheduling = True
|
|
987
|
+
self.runner._pre_async_results = MagicMock()
|
|
988
|
+
self.runner._pre_async_results.placeholder_req_id_to_index = {
|
|
989
|
+
"req1": 0
|
|
990
|
+
}
|
|
991
|
+
self.runner._pre_async_results.next_tokens = np.array([100])
|
|
992
|
+
|
|
993
|
+
# Setup required attributes
|
|
994
|
+
self.runner.uses_mrope = False
|
|
995
|
+
self.runner.phase_based_profiler = None
|
|
996
|
+
self.runner.lora_config = None
|
|
997
|
+
self.runner.mesh = MagicMock()
|
|
998
|
+
self.runner.data_parallel_sharding = MagicMock()
|
|
999
|
+
self.runner.data_parallel_attn_sharding = MagicMock()
|
|
1000
|
+
self.runner.mm_manager = MagicMock()
|
|
1001
|
+
self.runner.speculative_decoding_manager = MagicMock()
|
|
1002
|
+
self.runner.lora_utils = MagicMock()
|
|
1003
|
+
|
|
1004
|
+
# Mock the token substitution preparation
|
|
1005
|
+
mock_prepare_async = MagicMock(return_value=({
|
|
1006
|
+
0: [2],
|
|
1007
|
+
1: []
|
|
1008
|
+
}, {
|
|
1009
|
+
0: [0],
|
|
1010
|
+
1: []
|
|
1011
|
+
}))
|
|
1012
|
+
self.runner._prepare_async_token_substitution_indices_dp = mock_prepare_async
|
|
1013
|
+
|
|
1014
|
+
# Execute the method
|
|
1015
|
+
_ = self.runner._prepare_inputs_dp(scheduler_output)
|
|
1016
|
+
|
|
1017
|
+
# Verify async token substitution was called
|
|
1018
|
+
mock_prepare_async.assert_called_once()
|
|
1019
|
+
|
|
1020
|
+
@patch('tpu_inference.runner.tpu_runner.NamedSharding')
|
|
1021
|
+
@patch('tpu_inference.runner.tpu_runner.runner_utils')
|
|
1022
|
+
@patch('tpu_inference.runner.tpu_runner.device_array',
|
|
1023
|
+
side_effect=lambda mesh, tensors, **kwargs: tensors)
|
|
1024
|
+
@patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
|
|
1025
|
+
def test_prepare_inputs_dp_async_token_substitution_application(
|
|
1026
|
+
self, mock_sampling_metadata, mock_device_array, mock_runner_utils,
|
|
1027
|
+
mock_named_sharding):
|
|
1028
|
+
"""Test async token substitution application in DP mode."""
|
|
1029
|
+
|
|
1030
|
+
# Setup mocking
|
|
1031
|
+
def mock_get_padded_token_len(paddings_list, val):
|
|
1032
|
+
if val <= 2:
|
|
1033
|
+
return 4
|
|
1034
|
+
elif val <= 5:
|
|
1035
|
+
return 8
|
|
1036
|
+
else:
|
|
1037
|
+
return 16
|
|
1038
|
+
|
|
1039
|
+
mock_runner_utils.get_padded_token_len.side_effect = mock_get_padded_token_len
|
|
1040
|
+
mock_sampling_instance = MagicMock()
|
|
1041
|
+
mock_sampling_metadata.from_input_batch.return_value = mock_sampling_instance
|
|
1042
|
+
mock_named_sharding.return_value = MagicMock()
|
|
1043
|
+
|
|
1044
|
+
# Setup test data
|
|
1045
|
+
num_scheduled_tokens = {"req1": 3, "req2": 2}
|
|
1046
|
+
assigned_dp_ranks = {"req1": 0, "req2": 1}
|
|
1047
|
+
|
|
1048
|
+
self.runner.input_batch.num_reqs = 2
|
|
1049
|
+
self.runner.input_batch.req_ids = ["req1", "req2"]
|
|
1050
|
+
self.runner.input_batch.num_computed_tokens_cpu = np.array([4, 6])
|
|
1051
|
+
self.runner.input_batch.token_ids_cpu = np.zeros((8, 64),
|
|
1052
|
+
dtype=np.int32)
|
|
1053
|
+
|
|
1054
|
+
scheduler_output = self._create_mock_scheduler_output(
|
|
1055
|
+
num_scheduled_tokens, assigned_dp_ranks)
|
|
1056
|
+
|
|
1057
|
+
# Enable async scheduling with placeholders
|
|
1058
|
+
self.runner.scheduler_config.async_scheduling = True
|
|
1059
|
+
self.runner._pre_async_results = MagicMock()
|
|
1060
|
+
self.runner._pre_async_results.placeholder_req_id_to_index = {
|
|
1061
|
+
"req1": 0,
|
|
1062
|
+
"req2": 1
|
|
1063
|
+
}
|
|
1064
|
+
self.runner._pre_async_results.next_tokens = np.array([100, 200])
|
|
1065
|
+
|
|
1066
|
+
# Setup required attributes
|
|
1067
|
+
self.runner.uses_mrope = False
|
|
1068
|
+
self.runner.phase_based_profiler = None
|
|
1069
|
+
self.runner.lora_config = None
|
|
1070
|
+
self.runner.mesh = MagicMock()
|
|
1071
|
+
self.runner.data_parallel_sharding = MagicMock()
|
|
1072
|
+
self.runner.data_parallel_attn_sharding = MagicMock()
|
|
1073
|
+
self.runner.mm_manager = MagicMock()
|
|
1074
|
+
self.runner.speculative_decoding_manager = MagicMock()
|
|
1075
|
+
self.runner.lora_utils = MagicMock()
|
|
1076
|
+
|
|
1077
|
+
# Mock the async token substitution application
|
|
1078
|
+
mock_apply_async = MagicMock(
|
|
1079
|
+
return_value=np.array([1, 2, 100, 4, 5, 200, 7, 8]))
|
|
1080
|
+
self.runner._apply_async_token_substitution = mock_apply_async
|
|
1081
|
+
|
|
1082
|
+
# Execute the method
|
|
1083
|
+
_ = self.runner._prepare_inputs_dp(scheduler_output)
|
|
1084
|
+
|
|
1085
|
+
# Verify _apply_async_token_substitution was called
|
|
1086
|
+
mock_apply_async.assert_called_once()
|
|
1087
|
+
call_args = mock_apply_async.call_args[0]
|
|
1088
|
+
|
|
1089
|
+
# Verify indices were concatenated from both DP ranks
|
|
1090
|
+
token_in_tpu_cur_input_indices = call_args[1]
|
|
1091
|
+
token_in_tpu_pre_next_tokens_indices = call_args[2]
|
|
1092
|
+
|
|
1093
|
+
# Should have indices from both ranks
|
|
1094
|
+
assert len(token_in_tpu_cur_input_indices) == 2
|
|
1095
|
+
assert len(token_in_tpu_pre_next_tokens_indices) == 2
|
|
1096
|
+
|
|
1097
|
+
|
|
1098
|
+
# Allow invoking this test module directly as a script.
if __name__ == "__main__":
    pytest.main([__file__])
|