tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of tpu-inference might be problematic.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +14 -0
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +406 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +320 -0
- tests/layers/vllm/test_unquantized.py +662 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +25 -8
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +14 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +20 -3
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +20 -26
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/fused_moe_gmm.py +506 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +282 -0
- tpu_inference/layers/common/sharding.py +22 -3
- tpu_inference/layers/common/utils.py +94 -0
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +100 -455
- tpu_inference/layers/vllm/linear.py +64 -0
- tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
- tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
- tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
- tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
- tpu_inference/layers/vllm/quantization/__init__.py +19 -3
- tpu_inference/layers/vllm/quantization/awq.py +96 -82
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
- tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
- tpu_inference/layers/vllm/quantization/fp8.py +119 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
- tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +37 -16
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +113 -124
- tpu_inference/models/jax/gpt_oss.py +23 -7
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
- tpu_inference/models/jax/utils/weight_utils.py +32 -1
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +27 -29
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +69 -35
- tpu_inference/runner/kv_cache.py +14 -0
- tpu_inference/runner/kv_cache_manager.py +15 -2
- tpu_inference/runner/lora_utils.py +16 -1
- tpu_inference/runner/multimodal_manager.py +16 -2
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +30 -10
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +31 -30
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +23 -7
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
- tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
- tpu_inference/layers/vllm/linear_common.py +0 -208
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
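
The two new-file hunks below appear to correspond to `tests/runner/test_speculative_decoding_manager.py` (+368) and `tests/runner/test_structured_decoding_manager.py` (+220) from the list above; the `@@ -0,0 +1,368 @@` and `@@ -0,0 +1,220 @@` headers match those line counts.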
@@ -0,0 +1,368 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock, patch
+
+import jax
+import numpy as np
+import pytest
+from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
+                         SchedulerConfig, SpeculativeConfig, VllmConfig)
+from vllm.sampling_params import SamplingType
+from vllm.v1.outputs import DraftTokenIds
+
+from tpu_inference.runner.input_batch import CachedRequestState, InputBatch
+from tpu_inference.runner.speculative_decoding_manager import \
+    SpecDecodeMetadata
+from tpu_inference.runner.tpu_runner import TPUModelRunner
+from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
+
+
+class TestSpeculativeDecodingManager:
+
+    def setup_method(self):
+        # Mock JAX dependencies
+        self.mock_devices = [MagicMock(coords=i) for i in range(1)]
+        device_array = np.array(jax.devices()[:1]).reshape(1, 1, 1, 1)
+        self.mock_mesh = jax.make_mesh(device_array.shape,
+                                       ('data', 'attn_dp', 'expert', 'model'))
+        self.mock_rng_key = MagicMock()
+
+        with patch('jax.devices', return_value=self.mock_devices), \
+             patch('jax.make_mesh', return_value=self.mock_mesh), \
+             patch('jax.random.key', return_value=self.mock_rng_key), \
+             patch('tpu_inference.runner.tpu_runner.get_model', return_value=MagicMock()), \
+             patch('tpu_inference.runner.tpu_runner.make_optimized_mesh', return_value=self.mock_mesh):
+
+            model_config = ModelConfig(tokenizer_mode="auto",
+                                       trust_remote_code=False,
+                                       seed=0,
+                                       dtype='bfloat16')
+            cache_config = CacheConfig(
+                block_size=16,
+                gpu_memory_utilization=0.9,
+                swap_space=4,
+                cache_dtype="auto",
+            )
+            scheduler_config = SchedulerConfig(max_num_seqs=16,
+                                               max_model_len=1024,
+                                               is_encoder_decoder=False)
+            parallel_config = ParallelConfig(
+                pipeline_parallel_size=1,
+                tensor_parallel_size=1,
+                worker_use_ray=False,
+            )
+            speculative_config = SpeculativeConfig(
+                model='ngram',
+                num_speculative_tokens=5,
+                prompt_lookup_max=4,
+            )
+            vllm_config = VllmConfig(
+                model_config=model_config,
+                cache_config=cache_config,
+                scheduler_config=scheduler_config,
+                parallel_config=parallel_config,
+                speculative_config=speculative_config,
+                observability_config={},
+                additional_config={},
+            )
+
+            self.runner = TPUModelRunner(vllm_config,
+                                         devices=self.mock_devices)
+
+    def test_propose_draft_token_ids_dispatches_to_eagle(self):
+        """Tests that propose_draft_token_ids calls the correct eagle method."""
+        # 1. ===== Setup =====
+        # Set the drafter to be an Eagle3Proposer
+        self.runner.drafter = MagicMock(spec=Eagle3Proposer)
+        self.runner.speculative_config.method = "eagle3"
+
+        # Mock the eagle-specific proposal method
+        with patch.object(self.runner.speculative_decoding_manager,
+                          'propose_eagle3_draft_token_ids',
+                          return_value=[[10, 11]]) as mock_propose_eagle:
+
+            # 2. ===== Act =====
+            self.runner.speculative_decoding_manager.propose_draft_token_ids(
+                sampled_token_ids=[[1]],
+                aux_hidden_states=None,
+                attn_metadata=MagicMock(),
+                spec_decode_metadata=None,
+            )
+
+        # 3. ===== Assert =====
+        mock_propose_eagle.assert_called_once()
+        assert self.runner.speculative_decoding_manager._draft_token_ids == [
+            [10, 11]
+        ]
+
+    def test_propose_draft_token_ids_wrong_drafter_type(self):
+        """Tests that an assertion is raised if the drafter is not an NgramProposer."""
+        # The default drafter is NgramProposer, so we replace it with a generic mock
+        self.runner.drafter = MagicMock()
+        self.runner.speculative_config.method = "ngram"
+        with pytest.raises(AssertionError):
+            self.runner.speculative_decoding_manager.propose_draft_token_ids(
+                [[1]], None, MagicMock(), None)
+
+    def test_take_draft_token_ids(self):
+        """Tests the take_draft_token_ids method for speculative decoding."""
+        # Case 1: No draft tokens are available.
+        self.runner.speculative_decoding_manager._draft_token_ids = None
+        result = self.runner.take_draft_token_ids()
+        assert result is None
+
+        # Case 2: Draft tokens are available.
+        mock_req_ids = ["req-1", "req-2"]
+        mock_draft_ids = [[10, 11], [20, 21, 22]]
+
+        # Re-initialize input_batch for a clean state for this specific test
+        self.runner.input_batch = InputBatch(
+            max_num_reqs=self.runner.max_num_reqs,
+            max_model_len=self.runner.max_model_len,
+            max_num_batched_tokens=self.runner.max_num_tokens,
+            pin_memory=False,
+            vocab_size=self.runner.vocab_size,
+            block_sizes=[self.runner.block_size],
+            is_spec_decode=True,
+        )
+
+        # Add some requests to populate `input_batch.req_ids`
+        mock_sampling_params = MagicMock()
+        mock_sampling_params.sampling_type = SamplingType.GREEDY
+        mock_sampling_params.top_k = -1
+        mock_sampling_params.top_p = 1.0
+        mock_sampling_params.temperature = 0.0
+        mock_sampling_params.min_tokens = 0
+        mock_sampling_params.logprobs = None
+        mock_sampling_params.logit_bias = None
+        mock_sampling_params.allowed_token_ids = set()
+        mock_sampling_params.bad_words_token_ids = None
+        mock_sampling_params.all_stop_token_ids = set()
+
+        req1 = CachedRequestState(req_id="req-1",
+                                  prompt_token_ids=[1],
+                                  output_token_ids=[],
+                                  sampling_params=mock_sampling_params,
+                                  block_ids=([1], ),
+                                  num_computed_tokens=1,
+                                  lora_request=None,
+                                  mm_features=[],
+                                  pooling_params=None,
+                                  generator=None)
+        req2 = CachedRequestState(req_id="req-2",
+                                  prompt_token_ids=[2],
+                                  output_token_ids=[],
+                                  sampling_params=mock_sampling_params,
+                                  block_ids=([2], ),
+                                  num_computed_tokens=1,
+                                  lora_request=None,
+                                  mm_features=[],
+                                  pooling_params=None,
+                                  generator=None)
+        self.runner.input_batch.add_request(req1)
+        self.runner.input_batch.add_request(req2)
+
+        # Set the draft tokens to be taken
+        self.runner.speculative_decoding_manager._draft_token_ids = mock_draft_ids
+
+        # Call the method to be tested
+        result = self.runner.take_draft_token_ids()
+
+        # Assertions for the returned object
+        assert result is not None
+        assert isinstance(result, DraftTokenIds)
+        assert result.req_ids == mock_req_ids
+        assert result.draft_token_ids == mock_draft_ids
+
+        # Assert that the internal state is reset
+        assert self.runner.speculative_decoding_manager._draft_token_ids is None
+
+        # Case 3: Call again after taking, should return None
+        result_after = self.runner.take_draft_token_ids()
+        assert result_after is None
+
+    def _setup_spec_decode_metadata_test(self):
+        """Helper method to set up common test infrastructure for spec decode metadata tests."""
+        # Mock runner attributes needed by the function
+        self.runner.arange_cpu = np.arange(1024, dtype=np.int64)
+        # Make input_ids_cpu a sequence of numbers for easy verification
+        self.runner.input_ids_cpu = np.arange(1024, dtype=np.int32) * 10
+        self.runner.num_tokens_paddings = [16, 32, 64, 128, 256, 512, 1024]
+
+        # Mock the device_array function to just return the numpy arrays
+        def mock_device_array(mesh, *args, **kwargs):
+            # Skip mesh parameter and return the actual arrays
+            if len(args) == 1 and isinstance(args[0], tuple):
+                return args[0]
+            return args
+
+        self.mock_device_array = mock_device_array
+
+    @pytest.mark.parametrize(
+        "num_draft_tokens,cu_num_scheduled_tokens,padded_num_reqs,expected_logits_indices,expected_bonus_logits_indices,expected_target_logits_indices,expected_draft_token_ids",
+        [
+            (
+                # Normal case
+                [3, 0, 2, 0, 1],
+                [4, 104, 107, 207, 209],
+                8,
+                [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208],
+                [3, 4, 7, 8, 10, 0, 0, 0],
+                [0, 1, 2, 5, 6, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                [10, 20, 30, 1050, 1060, 2080, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
+            (
+                # High speculative tokens case
+                [5, 3, 4, 2, 1],
+                [6, 10, 18, 22, 26],
+                8,
+                [
+                    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 13, 14, 15, 16, 17, 19, 20,
+                    21, 24, 25
+                ],
+                [5, 9, 14, 17, 19, 0, 0, 0],
+                [
+                    0, 1, 2, 3, 4, 6, 7, 8, 10, 11, 12, 13, 15, 16, 18, 0, 0,
+                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
+                ],
+                [
+                    10, 20, 30, 40, 50, 70, 80, 90, 140, 150, 160, 170, 200,
+                    210, 250
+                ]),
+        ])
+    def test_get_spec_decode_metadata_parametrized(
+            self, num_draft_tokens, cu_num_scheduled_tokens, padded_num_reqs,
+            expected_logits_indices, expected_bonus_logits_indices,
+            expected_target_logits_indices, expected_draft_token_ids):
+        """Comprehensive parametrized test for _get_spec_decode_metadata function."""
+        # Setup
+        self._setup_spec_decode_metadata_test()
+
+        # Convert Python lists to numpy arrays for function input
+        num_draft_tokens_np = np.array(num_draft_tokens, dtype=np.int32)
+        cu_num_scheduled_tokens_np = np.array(cu_num_scheduled_tokens,
+                                              dtype=np.int32)
+
+        # Act
+        with patch(
+                "tpu_inference.runner.speculative_decoding_manager.device_array",
+                side_effect=self.mock_device_array):
+            metadata = self.runner.speculative_decoding_manager.get_spec_decode_metadata(
+                num_draft_tokens_np,
+                cu_num_scheduled_tokens_np,
+                padded_num_reqs=padded_num_reqs)
+
+        # Assert basic properties
+        assert isinstance(metadata, SpecDecodeMetadata)
+
+        # Determine padding length based on expected_logits_indices length
+        if len(expected_logits_indices) <= 16:
+            padded_len = 16
+        else:
+            padded_len = 32
+
+        # final_logits_indices - pad to bucket size and compare as Python lists
+        expected_padded_logits_indices = expected_logits_indices + [0] * (
+            padded_len - len(expected_logits_indices))
+        assert np.asarray(metadata.final_logits_indices).tolist(
+        ) == expected_padded_logits_indices
+
+        # bonus_logits_indices - compare as Python lists
+        assert np.asarray(metadata.bonus_logits_indices).tolist(
+        ) == expected_bonus_logits_indices
+
+        # target_logits_indices - pad to same length as final_logits_indices and compare as Python lists
+        expected_padded_target_logits_indices = expected_target_logits_indices + [
+            0
+        ] * (padded_len - len(expected_target_logits_indices))
+        assert np.asarray(metadata.target_logits_indices).tolist(
+        ) == expected_padded_target_logits_indices
+
+        # draft_token_ids - pad the expected values to the correct length and compare as Python lists
+        expected_padded_draft_token_ids = expected_draft_token_ids + [0] * (
+            padded_len - len(expected_draft_token_ids))
+        assert np.asarray(metadata.draft_token_ids).tolist(
+        ) == expected_padded_draft_token_ids
+
+        # draft_lengths - pad and compare as Python lists
+        expected_padded_num_draft_tokens = num_draft_tokens + [0] * (
+            padded_num_reqs - len(num_draft_tokens))
+        assert np.asarray(metadata.draft_lengths).tolist(
+        ) == expected_padded_num_draft_tokens
+
+    @pytest.mark.parametrize("spec_decode_metadata_is_none", [True, False])
+    def test_propose_eagle3_draft_token_ids(self,
+                                            spec_decode_metadata_is_none):
+        """Tests the logic for proposing Eagle3 draft tokens."""
+        # 1. ===== Setup =====
+        self.runner.drafter = MagicMock(spec=Eagle3Proposer)
+        self.runner.speculative_config.method = "eagle3"
+
+        # Mock TPUModelRunner attributes
+        self.runner.input_batch = MagicMock()
+        self.runner.input_batch.req_ids = ["req-1", "req-2"]
+        self.runner.requests = {
+            "req-1": MagicMock(),
+            "req-2": MagicMock(),
+        }
+        self.runner.mesh = self.mock_mesh
+        self.runner.kv_caches = MagicMock()
+
+        # Mock drafter methods
+        mock_attn_metadata = MagicMock()
+        mock_target_token_ids = MagicMock()
+        mock_last_token_indices = MagicMock()
+        mock_target_hidden_states = MagicMock()
+        self.runner.drafter.prepare_inputs.return_value = (
+            mock_target_hidden_states,
+            mock_target_token_ids,
+            mock_last_token_indices,
+            mock_attn_metadata,
+        )
+        mock_draft_token_ids = [[10, 11], [20, 21]]
+        self.runner.drafter.propose.return_value = (
+            self.runner.kv_caches,
+            mock_draft_token_ids,
+        )
+
+        # Inputs
+        sampled_token_ids = [[1], [2]]
+        aux_hidden_states = MagicMock()
+        attn_metadata = MagicMock()
+        attn_metadata.seq_lens.shape = [2]
+        if spec_decode_metadata_is_none:
+            spec_decode_metadata = None
+        else:
+            spec_decode_metadata = MagicMock(spec=SpecDecodeMetadata)
+            spec_decode_metadata.draft_lengths_cpu = np.array([2, 3])
+        scheduler_output = MagicMock()
+        input_ids = MagicMock()
+
+        # 2. ===== Act =====
+        with patch(
+                "tpu_inference.runner.speculative_decoding_manager.device_array",
+                side_effect=lambda mesh, x: x):
+            result = self.runner.speculative_decoding_manager.propose_eagle3_draft_token_ids(
+                sampled_token_ids,
+                aux_hidden_states,
+                attn_metadata,
+                spec_decode_metadata,
+                scheduler_output,
+                input_ids,
+            )
+
+        # 3. ===== Assert =====
+        assert result == [[10, 11], [20, 21]]
+        self.runner.drafter.prepare_inputs.assert_called_once()
+        self.runner.drafter.propose.assert_called_once()
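
The expected index lists in `test_get_spec_decode_metadata_parametrized` follow a simple pattern: for request i, the gathered logits are its last `num_draft_tokens[i] + 1` scheduled positions, the bonus logit is the last of those, and the draft token ids are the input ids at the last `num_draft_tokens[i]` positions. A minimal numpy sketch (an illustration of that pattern, not the tpu_inference implementation; padding to the bucket sizes is omitted) that reproduces the first parametrized case:

```python
import numpy as np

# Inputs from the "Normal case" parametrization above.
num_draft_tokens = [3, 0, 2, 0, 1]
cu_num_scheduled = [4, 104, 107, 207, 209]
input_ids = np.arange(1024, dtype=np.int32) * 10  # as in the test helper

logits_indices, bonus, target, draft_ids = [], [], [], []
for draft, cu in zip(num_draft_tokens, cu_num_scheduled):
    positions = list(range(cu - draft - 1, cu))  # last draft+1 positions
    target += range(len(logits_indices), len(logits_indices) + draft)
    logits_indices += positions
    bonus.append(len(logits_indices) - 1)        # last position is the bonus
    draft_ids += input_ids[cu - draft:cu].tolist()

assert logits_indices == [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
assert bonus == [3, 4, 7, 8, 10]
assert target == [0, 1, 2, 5, 6, 9]
assert draft_ids == [10, 20, 30, 1050, 1060, 2080]
```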
@@ -0,0 +1,220 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock, patch
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
+                         SchedulerConfig, SpeculativeConfig, VllmConfig)
+from vllm.sampling_params import SamplingType
+
+from tpu_inference.runner.input_batch import CachedRequestState
+from tpu_inference.runner.tpu_runner import TPUModelRunner
+
+
+class TestStructuredDecodingManager:
+
+    def setup_method(self):
+        # Mock JAX dependencies
+        self.mock_rng_key = MagicMock()
+        self.mock_devices = [MagicMock(coords=i) for i in range(1)]
+        device_array = np.array(jax.devices()[:1]).reshape(1, 1, 1, 1)
+        self.mock_mesh = jax.make_mesh(device_array.shape,
+                                       ('data', 'attn_dp', 'expert', 'model'))
+        self.mock_rng_key = MagicMock()
+
+
+        with patch('jax.devices', return_value=self.mock_devices), \
+             patch('jax.make_mesh', return_value=self.mock_mesh), \
+             patch('jax.random.key', return_value=self.mock_rng_key), \
+             patch('tpu_inference.runner.tpu_runner.get_model', return_value=MagicMock()), \
+             patch('tpu_inference.runner.tpu_runner.make_optimized_mesh', return_value=self.mock_mesh):
+
+            model_config = ModelConfig(tokenizer_mode="auto",
+                                       trust_remote_code=False,
+                                       seed=0,
+                                       dtype='bfloat16')
+            cache_config = CacheConfig(
+                block_size=16,
+                gpu_memory_utilization=0.9,
+                swap_space=4,
+                cache_dtype="auto",
+            )
+            scheduler_config = SchedulerConfig(max_num_seqs=16,
+                                               max_model_len=1024,
+                                               is_encoder_decoder=False)
+            parallel_config = ParallelConfig(
+                pipeline_parallel_size=1,
+                tensor_parallel_size=1,
+                worker_use_ray=False,
+            )
+            speculative_config = SpeculativeConfig(
+                model='ngram',
+                num_speculative_tokens=5,
+                prompt_lookup_max=4,
+            )
+            vllm_config = VllmConfig(
+                model_config=model_config,
+                cache_config=cache_config,
+                scheduler_config=scheduler_config,
+                parallel_config=parallel_config,
+                speculative_config=speculative_config,
+                observability_config={},
+                additional_config={},
+            )
+
+            self.runner = TPUModelRunner(vllm_config,
+                                         devices=self.mock_devices)
+
+    def test_structured_decoding(self):
+        # 1. ===== Setup =====
+        # Configure runner for the test
+        self.runner.model_config.get_vocab_size = MagicMock(return_value=64)
+        self.runner._init_inputs()  # re-initialize with new vocab size
+
+        # Mock device_array to avoid JAX sharding issues with MagicMock mesh
+        def mock_device_array(mesh, *args, sharding=None, **kwargs):
+            # Simply return the arguments without any sharding (skip mesh parameter)
+            if len(args) == 1 and isinstance(args[0], tuple):
+                return args[0]  # Return tuple as is
+            elif len(args) == 1:
+                return args[0]  # Return single array as is
+            else:
+                return args  # Return all arguments as tuple
+
+        # Patch the centralized device_array function instead of runner's method
+        with patch(
+                'tpu_inference.runner.structured_decoding_manager.device_array',
+                side_effect=mock_device_array):
+
+            # Create a mock for sampling_params to avoid TypeErrors in add_request
+            mock_sampling_params = MagicMock()
+            mock_sampling_params.sampling_type = SamplingType.GREEDY
+            mock_sampling_params.temperature = 0.0
+            mock_sampling_params.top_p = 1.0
+            mock_sampling_params.top_k = -1
+            mock_sampling_params.min_tokens = 0
+            mock_sampling_params.logprobs = None
+            mock_sampling_params.logit_bias = None
+            mock_sampling_params.allowed_token_ids = set()
+            mock_sampling_params.bad_words_token_ids = None
+            mock_sampling_params.all_stop_token_ids = set()
+
+            # Add requests to the input batch
+            req1 = CachedRequestState(
+                req_id="req-1",
+                prompt_token_ids=[1],
+                output_token_ids=[],
+                sampling_params=mock_sampling_params,
+                block_ids=([1], ),
+                num_computed_tokens=1,
+                lora_request=None,
+                mm_features=[],
+                pooling_params=None,
+                generator=None,
+            )
+            req2 = CachedRequestState(
+                req_id="req-2",
+                prompt_token_ids=[2],
+                output_token_ids=[],
+                sampling_params=mock_sampling_params,
+                block_ids=([2], ),
+                num_computed_tokens=1,
+                lora_request=None,
+                mm_features=[],
+                pooling_params=None,
+                generator=None,
+            )
+            req3 = CachedRequestState(
+                req_id="req-3",
+                prompt_token_ids=[3],
+                output_token_ids=[],
+                sampling_params=mock_sampling_params,
+                block_ids=([3], ),
+                num_computed_tokens=1,
+                lora_request=None,
+                mm_features=[],
+                pooling_params=None,
+                generator=None,
+            )
+            self.runner.input_batch.add_request(req1)  # index 0
+            self.runner.input_batch.add_request(req2)  # index 1
+            self.runner.input_batch.add_request(req3)  # index 2
+            num_reqs = 3
+
+            # Mock scheduler output for structured decoding
+            # req-1 and req-3 require structured decoding
+            mock_scheduler_output = MagicMock()
+            mock_scheduler_output.structured_output_request_ids = {
+                "req-1": 0,  # maps req_id to index in grammar_bitmask
+                "req-3": 1,
+            }
+            # Bitmask: vocab_size=64, so 2 int32s per request
+            # Mask for req-1: allow tokens 0-31
+            mask1 = np.array([-1, 0], dtype=np.int32)
+            # Mask for req-3: allow tokens 32-63
+            mask2 = np.array([0, -1], dtype=np.int32)
+            mock_scheduler_output.grammar_bitmask = np.array([mask1, mask2])
+
+            # Mock logits
+            logits_shape = (num_reqs, self.runner.vocab_size)
+            mock_logits_device = jnp.ones(logits_shape, dtype=jnp.bfloat16)
+
+            # 2. ===== Test prepare_structured_decoding_input =====
+            (
+                require_struct_decoding, grammar_bitmask, arange
+            ) = self.runner.structured_decoding_manager.prepare_structured_decoding_input(
+                mock_logits_device, mock_scheduler_output)
+
+            # Assertions for prepare_structured_decoding_input
+            # require_structured_out_cpu should be [True, False, True]
+            # because req-1 is at batch index 0, req-2 at 1, req-3 at 2
+            expected_require_struct = np.array([[True], [False], [True]],
+                                               dtype=np.bool_)
+            np.testing.assert_array_equal(np.array(require_struct_decoding),
+                                          expected_require_struct)
+
+            # grammar_bitmask_cpu should have mask1 at index 0, mask2 at index 2
+            expected_grammar_bitmask = np.zeros_like(
+                self.runner.grammar_bitmask_cpu[:num_reqs])
+            expected_grammar_bitmask[0] = mask1
+            expected_grammar_bitmask[2] = mask2
+            np.testing.assert_array_equal(np.array(grammar_bitmask),
+                                          expected_grammar_bitmask)
+
+            np.testing.assert_array_equal(np.array(arange),
+                                          np.arange(0, 32, dtype=np.int32))
+
+            # 3. ===== Test structured_decode_fn =====
+            # This function is jitted, so we call it with the device arrays
+            modified_logits = self.runner.structured_decoding_manager.structured_decode_fn(
+                require_struct_decoding, grammar_bitmask, mock_logits_device,
+                arange)
+
+            modified_logits_cpu = np.array(modified_logits)
+
+            # Assertions for structured_decode_fn
+            # Logits for req-1 (index 0) should be masked for tokens 32-63
+            assert np.all(modified_logits_cpu[0, :32] == 1.0)
+            assert np.all(modified_logits_cpu[0, 32:] == -np.inf)
+
+            # Logits for req-2 (index 1) should be unchanged
+            np.testing.assert_array_equal(modified_logits_cpu[1],
+                                          np.ones(self.runner.vocab_size))
+
+            # Logits for req-3 (index 2) should be masked for tokens 0-31
+            assert np.all(modified_logits_cpu[2, :32] == -np.inf)
+            assert np.all(modified_logits_cpu[2, 32:] == 1.0)