tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +14 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +25 -8
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +14 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +20 -3
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +20 -26
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +22 -3
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +100 -455
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
- tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +37 -16
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +113 -124
- tpu_inference/models/jax/gpt_oss.py +23 -7
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
- tpu_inference/models/jax/utils/weight_utils.py +32 -1
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +27 -29
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +69 -35
- tpu_inference/runner/kv_cache.py +14 -0
- tpu_inference/runner/kv_cache_manager.py +15 -2
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +30 -10
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +31 -30
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +23 -7
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -208
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tests/runner/test_block_table.py

@@ -0,0 +1,395 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# test_block_table_jax.py

import jax
import jax.numpy as jnp
import numpy as np
import pytest


def cdiv(a: int, b: int) -> int:
    """Ceiling division: (a + b - 1) // b."""
    return (a + b - 1) // b


class BlockTable:
    """A JAX-compatible BlockTable for managing memory blocks."""

    def __init__(
        self,
        max_num_reqs: int,
        max_num_blocks_per_req: int,
        max_num_batched_tokens: int,
        pin_memory: bool,  # Note: pin_memory is not used in JAX
    ):
        self.max_num_reqs = max_num_reqs
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.block_table = jnp.zeros((max_num_reqs, max_num_blocks_per_req),
                                     dtype=jnp.int32)
        self.block_table_cpu = np.zeros((max_num_reqs, max_num_blocks_per_req),
                                        dtype=np.int32)
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

    def append_row(self, block_ids: list[int], row_idx: int) -> None:
        if not block_ids:
            return
        num_blocks = len(block_ids)
        start = self.num_blocks_per_row[row_idx]
        self.num_blocks_per_row[row_idx] += num_blocks
        self.block_table_cpu[row_idx, start:start + num_blocks] = block_ids

    def add_row(self, block_ids: list[int], row_idx: int) -> None:
        self.num_blocks_per_row[row_idx] = 0
        # Clear the row for a clean overwrite
        self.block_table_cpu[row_idx].fill(0)
        self.append_row(block_ids, row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        num_blocks = self.num_blocks_per_row[src]
        self.block_table_cpu[tgt, :num_blocks] = self.block_table_cpu[
            src, :num_blocks]
        # Clear the rest of the target row to avoid stale data
        self.block_table_cpu[tgt, num_blocks:].fill(0)
        self.num_blocks_per_row[tgt] = num_blocks

    def swap_row(self, src: int, tgt: int) -> None:
        self.num_blocks_per_row[[src,
                                 tgt]] = self.num_blocks_per_row[[tgt, src]]
        self.block_table_cpu[[src, tgt]] = self.block_table_cpu[[tgt, src]]

    def commit(self, num_reqs: int) -> None:
        """Corrected commit for JAX immutability."""
        self.block_table = self.block_table.at[:num_reqs].set(
            self.block_table_cpu[:num_reqs])

    def clear(self) -> None:
        """Corrected clear for JAX immutability and completeness."""
        self.block_table = jnp.zeros_like(self.block_table)
        self.block_table_cpu.fill(0)
        self.num_blocks_per_row.fill(0)

    def get_device_tensor(self) -> jax.Array:
        return self.block_table

    def get_cpu_tensor(self) -> np.ndarray:
        return self.block_table_cpu


class MultiGroupBlockTable:
    """Manages BlockTables for each KV cache group."""

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        pin_memory: bool,
        block_sizes: list[int],
    ) -> None:
        self.block_tables = [
            BlockTable(
                max_num_reqs,
                cdiv(max_model_len, block_size),
                max_num_batched_tokens,
                pin_memory,
            ) for block_size in block_sizes
        ]

    def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:
        for i, block_table in enumerate(self.block_tables):
            block_table.append_row(block_ids[i], row_idx)

    def add_row(self, block_ids: list[list[int]], row_idx: int) -> None:
        for i, block_table in enumerate(self.block_tables):
            block_table.add_row(block_ids[i], row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        for block_table in self.block_tables:
            block_table.move_row(src, tgt)

    def swap_row(self, src: int, tgt: int) -> None:
        for block_table in self.block_tables:
            block_table.swap_row(src, tgt)

    def commit(self, num_reqs: int) -> None:
        for block_table in self.block_tables:
            block_table.commit(num_reqs)

    def clear(self) -> None:
        for block_table in self.block_tables:
            block_table.clear()

    def __getitem__(self, idx: int) -> "BlockTable":
        return self.block_tables[idx]


# --- Pytest Fixtures ---


@pytest.fixture
def block_table_params():
    """Provides common parameters for creating a BlockTable."""
    return {
        "max_num_reqs": 8,
        "max_num_blocks_per_req": 16,
        "max_num_batched_tokens": 8 * 16,
        "pin_memory": False,
    }


@pytest.fixture
def block_table(block_table_params):
    """Provides a fresh BlockTable instance for each test."""
    return BlockTable(**block_table_params)


# --- Test Cases ---

##
## BlockTable Tests
##


class TestBlockTable:
    """Tests for the single BlockTable class."""

    def test_init(self, block_table, block_table_params):
        """Test constructor and initial state."""
        bt = block_table
        params = block_table_params

        assert bt.max_num_reqs == params["max_num_reqs"]
        assert bt.max_num_blocks_per_req == params["max_num_blocks_per_req"]

        # Check CPU table
        assert bt.block_table_cpu.shape == (
            params["max_num_reqs"],
            params["max_num_blocks_per_req"],
        )
        assert bt.block_table_cpu.dtype == np.int32
        np.testing.assert_array_equal(bt.block_table_cpu, 0)

        # Check device table
        assert bt.block_table.shape == (
            params["max_num_reqs"],
            params["max_num_blocks_per_req"],
        )
        assert bt.block_table.dtype == jnp.int32
        np.testing.assert_array_equal(np.array(bt.block_table), 0)

        # Check block counter per row
        assert bt.num_blocks_per_row.shape == (params["max_num_reqs"], )
        np.testing.assert_array_equal(bt.num_blocks_per_row, 0)

    def test_add_and_append_row(self, block_table):
        """Test adding and appending blocks to a row."""
        # Append to row 0
        block_table.append_row([1, 2, 3], row_idx=0)
        assert block_table.num_blocks_per_row[0] == 3
        np.testing.assert_array_equal(block_table.block_table_cpu[0, :3],
                                      [1, 2, 3])

        # Append more to row 0
        block_table.append_row([4, 5], row_idx=0)
        assert block_table.num_blocks_per_row[0] == 5
        np.testing.assert_array_equal(block_table.block_table_cpu[0, :5],
                                      [1, 2, 3, 4, 5])

        # Add (overwrite) row 1
        block_table.add_row([10, 11], row_idx=1)
        assert block_table.num_blocks_per_row[1] == 2
        np.testing.assert_array_equal(block_table.block_table_cpu[1, :2],
                                      [10, 11])

        # Add (overwrite) row 0
        block_table.add_row([6, 7, 8, 9], row_idx=0)
        assert block_table.num_blocks_per_row[0] == 4
        np.testing.assert_array_equal(block_table.block_table_cpu[0, :4],
                                      [6, 7, 8, 9])
        assert block_table.block_table_cpu[
            0, 4] == 0  # Ensure rest of row is clear

    def test_move_row(self, block_table):
        """Test moving a row's content."""
        block_table.add_row([10, 20, 30], row_idx=2)
        block_table.add_row([99], row_idx=5)  # Pre-existing data

        block_table.move_row(src=2, tgt=5)

        # Check target row
        assert block_table.num_blocks_per_row[5] == 3
        np.testing.assert_array_equal(block_table.get_cpu_tensor()[5, :3],
                                      [10, 20, 30])
        assert block_table.get_cpu_tensor()[
            5, 3] == 0  # Check old data is cleared

        # Check source row (should be unchanged)
        assert block_table.num_blocks_per_row[2] == 3
        np.testing.assert_array_equal(block_table.get_cpu_tensor()[2, :3],
                                      [10, 20, 30])

    def test_swap_row(self, block_table):
        """Test swapping two rows."""
        row_2_data = [10, 20, 30]
        row_5_data = [99, 88]
        block_table.add_row(row_2_data, row_idx=2)
        block_table.add_row(row_5_data, row_idx=5)

        block_table.swap_row(src=2, tgt=5)

        # Check that data and counts are swapped
        assert block_table.num_blocks_per_row[2] == 2
        assert block_table.num_blocks_per_row[5] == 3
        np.testing.assert_array_equal(block_table.block_table_cpu[2, :2],
                                      row_5_data)
        np.testing.assert_array_equal(block_table.block_table_cpu[5, :3],
                                      row_2_data)

    def test_commit(self, block_table):
        """Test committing the CPU table to the JAX device table."""
        block_table.add_row([1, 2, 3], row_idx=0)
        block_table.add_row([4, 5], row_idx=1)
        num_reqs_to_commit = 2

        # Before commit, device tensor is all zeros
        np.testing.assert_array_equal(
            np.array(block_table.get_device_tensor()), 0)

        block_table.commit(num_reqs_to_commit)
        device_table = np.array(block_table.get_device_tensor())

        # After commit, device tensor should match committed part of CPU tensor
        np.testing.assert_array_equal(
            device_table[:num_reqs_to_commit],
            block_table.get_cpu_tensor()[:num_reqs_to_commit],
        )
        # The rest of the device tensor should still be zero
        np.testing.assert_array_equal(device_table[num_reqs_to_commit:], 0)

    def test_clear(self, block_table):
        """Test clearing all table data."""
        block_table.add_row([1, 2, 3], row_idx=0)
        block_table.commit(num_reqs=1)

        # Pre-clear check
        assert np.any(block_table.get_cpu_tensor())
        assert jnp.any(block_table.get_device_tensor())
        assert np.any(block_table.num_blocks_per_row)

        block_table.clear()

        # Post-clear check
        np.testing.assert_array_equal(block_table.get_cpu_tensor(), 0)
        np.testing.assert_array_equal(
            np.array(block_table.get_device_tensor()), 0)
        np.testing.assert_array_equal(block_table.num_blocks_per_row, 0)


# ------------------------------------
# MultiGroupBlockTable Tests
# ------------------------------------


class TestMultiGroupBlockTable:
    """Tests for the MultiGroupBlockTable class."""

    @pytest.fixture
    def multi_table_params(self):
        return {
            "max_num_reqs": 4,
            "max_model_len": 32,
            "max_num_batched_tokens": 4 * 32,
            "pin_memory": False,
            "block_sizes": [16, 8],  # Two groups
        }

    @pytest.fixture
    def multi_table(self, multi_table_params):
        return MultiGroupBlockTable(**multi_table_params)

    def test_init(self, multi_table, multi_table_params):
        """Test constructor and initial state of multiple tables."""
        params = multi_table_params
        assert len(multi_table.block_tables) == len(params["block_sizes"])
        assert isinstance(multi_table[0], BlockTable)
        assert isinstance(multi_table[1], BlockTable)

        # Check that max_num_blocks_per_req is calculated correctly
        assert multi_table[0].max_num_blocks_per_req == cdiv(
            params["max_model_len"], params["block_sizes"][0])  # 32 / 16 = 2
        assert multi_table[1].max_num_blocks_per_req == cdiv(
            params["max_model_len"], params["block_sizes"][1])  # 32 / 8 = 4

    def test_add_row(self, multi_table):
        """Test add_row across multiple tables."""
        block_ids = [[101, 102], [201, 202, 203]]
        multi_table.add_row(block_ids, row_idx=0)

        # Check table 0
        assert multi_table[0].num_blocks_per_row[0] == 2
        np.testing.assert_array_equal(multi_table[0].get_cpu_tensor()[0, :2],
                                      block_ids[0])

        # Check table 1
        assert multi_table[1].num_blocks_per_row[0] == 3
        np.testing.assert_array_equal(multi_table[1].get_cpu_tensor()[0, :3],
                                      block_ids[1])

    def test_swap_row(self, multi_table):
        """Test swap_row across multiple tables."""
        row1_data = [[11], [11, 22]]
        row3_data = [[33], [33, 44, 55]]
        multi_table.add_row(row1_data, row_idx=1)
        multi_table.add_row(row3_data, row_idx=3)

        multi_table.swap_row(src=1, tgt=3)

        # Check row 1 now has row 3's data
        assert multi_table[0].num_blocks_per_row[1] == 1
        np.testing.assert_array_equal(multi_table[0].get_cpu_tensor()[1, :1],
                                      row3_data[0])
        assert multi_table[1].num_blocks_per_row[1] == 3
        np.testing.assert_array_equal(multi_table[1].get_cpu_tensor()[1, :3],
                                      row3_data[1])

        # Check row 3 now has row 1's data
        assert multi_table[0].num_blocks_per_row[3] == 1
        np.testing.assert_array_equal(multi_table[0].get_cpu_tensor()[3, :1],
                                      row1_data[0])
        assert multi_table[1].num_blocks_per_row[3] == 2
        np.testing.assert_array_equal(multi_table[1].get_cpu_tensor()[3, :2],
                                      row1_data[1])

    def test_commit_and_clear(self, multi_table):
        """Test commit and clear across multiple tables."""
        multi_table.add_row([[1], [1, 2]], row_idx=0)
        multi_table.commit(num_reqs=1)

        # Check commit worked for all tables
        for table in multi_table.block_tables:
            assert jnp.any(table.get_device_tensor())
            device_table = np.array(table.get_device_tensor())
            cpu_table = table.get_cpu_tensor()
            np.testing.assert_array_equal(device_table, cpu_table)

        multi_table.clear()

        # Check clear worked for all tables
        for table in multi_table.block_tables:
            np.testing.assert_array_equal(table.get_cpu_tensor(), 0)
            np.testing.assert_array_equal(np.array(table.get_device_tensor()),
                                          0)
            np.testing.assert_array_equal(table.num_blocks_per_row, 0)
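The tests above exercise a CPU-staging pattern: all row mutations happen in the NumPy buffer, and `commit` copies only the active prefix into the immutable JAX device array via `.at[:num_reqs].set(...)`. A minimal usage sketch of that pattern, reusing the `BlockTable` class defined in the hunk above (the parameter values here are illustrative, not taken from the diff):

# Stage rows on CPU, then publish them to the device array.
table = BlockTable(max_num_reqs=4, max_num_blocks_per_req=8,
                   max_num_batched_tokens=32, pin_memory=False)
table.add_row([3, 7], row_idx=0)                  # mutates only the NumPy buffer
assert int(table.get_device_tensor()[0, 0]) == 0  # device copy is still stale
table.commit(num_reqs=1)                          # row 0 is copied to the device
assert int(table.get_device_tensor()[0, 0]) == 3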
tests/runner/test_input_batch.py

@@ -0,0 +1,226 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest
from vllm.sampling_params import SamplingParams

from tpu_inference.runner.input_batch import CachedRequestState, InputBatch

# Default parameters for creating InputBatch instances in tests
MAX_NUM_REQS = 8
MAX_MODEL_LEN = 1024
MAX_NUM_BATCHED_TOKENS = 2048
VOCAB_SIZE = 32000
BLOCK_SIZES = [16]


@pytest.fixture
def input_batch():
    """Provides a clean InputBatch instance for each test."""
    return InputBatch(
        max_num_reqs=MAX_NUM_REQS,
        max_model_len=MAX_MODEL_LEN,
        max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS,
        pin_memory=False,
        vocab_size=VOCAB_SIZE,
        block_sizes=BLOCK_SIZES,
        is_spec_decode=True,
    )


def create_dummy_request(req_id: str,
                         prompt_len: int = 10,
                         output_len: int = 5,
                         sampling_params: SamplingParams = None,
                         block_ids=None) -> CachedRequestState:
    """Helper function to create a CachedRequestState instance."""
    if sampling_params is None:
        sampling_params = SamplingParams(temperature=0.8, top_p=0.9, top_k=50)

    prompt_token_ids = list(range(prompt_len))
    output_token_ids = list(range(prompt_len, prompt_len + output_len))

    if block_ids is None:
        # Create dummy block ids based on length
        num_blocks = (prompt_len + output_len + BLOCK_SIZES[0] -
                      1) // BLOCK_SIZES[0]
        block_ids = [[i] for i in range(1, num_blocks + 1)]

    return CachedRequestState(
        req_id=req_id,
        prompt_token_ids=prompt_token_ids,
        mm_features=[],
        sampling_params=sampling_params,
        pooling_params=None,
        block_ids=block_ids,
        num_computed_tokens=0,
        lora_request=None,
        output_token_ids=output_token_ids,
    )


def test_initialization(input_batch: InputBatch):
    """Tests if the InputBatch is initialized with correct default values."""
    assert input_batch.max_num_reqs == MAX_NUM_REQS
    assert input_batch.num_reqs == 0
    assert len(input_batch.req_ids) == 0
    assert not input_batch.req_id_to_index
    assert input_batch.all_greedy
    assert input_batch.is_spec_decode


def test_add_request(input_batch: InputBatch):
    """Tests adding a single request to the batch."""
    req = create_dummy_request("req-1", prompt_len=20, output_len=4)
    input_batch.add_request(req)

    assert input_batch.num_reqs == 1
    assert "req-1" in input_batch.req_id_to_index
    assert input_batch.req_id_to_index["req-1"] == 0
    assert input_batch.req_ids == ["req-1"]
    assert len(input_batch.spec_decode_unsupported_reqs) == 0

    # Verify token data
    assert input_batch.num_prompt_tokens[0] == 20
    assert input_batch.num_tokens[0] == 24
    assert input_batch.num_tokens_no_spec[0] == 24
    expected_tokens = np.array(req.prompt_token_ids + req.output_token_ids)
    np.testing.assert_array_equal(input_batch.token_ids_cpu[0, :24],
                                  expected_tokens)

    # Verify sampling params
    assert input_batch.temperature_cpu[0] == 0.8
    assert input_batch.top_p_cpu[0] == 0.9
    assert input_batch.top_k_cpu[0] == 50


def test_add_multiple_requests(input_batch: InputBatch):
    """Tests adding multiple requests and checks their indices."""
    req1 = create_dummy_request("req-1")
    req2 = create_dummy_request("req-2")

    input_batch.add_request(req1)
    input_batch.add_request(req2)

    assert input_batch.num_reqs == 2
    assert input_batch.req_ids == ["req-1", "req-2"]
    assert input_batch.req_id_to_index["req-1"] == 0
    assert input_batch.req_id_to_index["req-2"] == 1
    assert input_batch.num_tokens[1] == len(req2.prompt_token_ids) + len(
        req2.output_token_ids)
    assert input_batch.num_tokens_no_spec[1] == len(
        req2.prompt_token_ids) + len(req2.output_token_ids)


def test_remove_request(input_batch: InputBatch):
    """Tests removing a request, which leaves a gap in the batch."""
    req1 = create_dummy_request("req-1")
    req2 = create_dummy_request("req-2")
    input_batch.add_request(req1)
    input_batch.add_request(req2)

    removed_index = input_batch.remove_request("req-1")

    assert removed_index == 0
    assert input_batch.num_reqs == 1
    assert "req-1" not in input_batch.req_id_to_index
    assert input_batch._req_ids[0] is None  # Slot is now empty
    assert input_batch._req_ids[1] == "req-2"
    assert "req-1" not in input_batch.greedy_reqs


def test_condense(input_batch: InputBatch):
    """Tests condensing the batch after removing requests."""
    reqs = [create_dummy_request(f"req-{i}") for i in range(4)]
    for req in reqs:
        input_batch.add_request(req)

    # Remove requests from the middle and start
    input_batch.remove_request("req-1")
    input_batch.remove_request("req-0")

    # Before condense: [None, None, "req-2", "req-3"]
    assert input_batch._req_ids[0] is None
    assert input_batch._req_ids[1] is None
    assert input_batch.num_reqs == 2

    # Condense should move req-2 and req-3 to the front
    empty_indices = sorted([0, 1], reverse=True)
    input_batch.condense(empty_indices)

    assert input_batch.num_reqs == 2
    assert len(input_batch.req_ids) == 2
    assert input_batch.req_ids == ["req-3", "req-2"]
    assert input_batch.req_id_to_index["req-2"] == 1
    assert input_batch.req_id_to_index["req-3"] == 0

    # Check if a property was moved correctly
    assert input_batch.num_tokens[0] == len(reqs[2].prompt_token_ids) + len(
        reqs[2].output_token_ids)
    assert input_batch.num_tokens_no_spec[0] == len(
        reqs[2].prompt_token_ids) + len(reqs[2].output_token_ids)


def test_swap_states(input_batch: InputBatch):
    """Tests swapping the states of two requests."""
    req1 = create_dummy_request("req-1", prompt_len=10, output_len=1)
    req2 = create_dummy_request("req-2",
                                prompt_len=20,
                                output_len=2,
                                sampling_params=SamplingParams(top_p=0.5))

    input_batch.add_request(req1)
    input_batch.add_request(req2)

    # Capture states before swap
    req1_tokens_before = input_batch.token_ids_cpu[0].copy()
    req2_tokens_before = input_batch.token_ids_cpu[1].copy()
    req1_top_p_before = input_batch.top_p_cpu[0]
    req2_top_p_before = input_batch.top_p_cpu[1]

    input_batch.swap_states(0, 1)

    # Check IDs and mappings
    assert input_batch.req_ids == ["req-2", "req-1"]
    assert input_batch.req_id_to_index["req-1"] == 1
    assert input_batch.req_id_to_index["req-2"] == 0

    # Check swapped data
    assert input_batch.top_p_cpu[0] == req2_top_p_before
    assert input_batch.top_p_cpu[1] == req1_top_p_before
    np.testing.assert_array_equal(input_batch.token_ids_cpu[0],
                                  req2_tokens_before)
    np.testing.assert_array_equal(input_batch.token_ids_cpu[1],
                                  req1_tokens_before)


def test_all_greedy_property(input_batch: InputBatch):
    """Tests the `all_greedy` property."""
    # Initially true
    assert input_batch.all_greedy

    # Add a greedy request, still true
    req_greedy = create_dummy_request(
        "req-g", sampling_params=SamplingParams(temperature=0.0))
    input_batch.add_request(req_greedy)
    assert input_batch.all_greedy

    # Manually add a random request for testing purposes
    input_batch.random_reqs.add("req-r")
    assert not input_batch.all_greedy

    # Remove it, should be true again
    input_batch.random_reqs.remove("req-r")
    assert input_batch.all_greedy