tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.2rc3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of tpu-inference might be problematic. Click here for more details.
- tests/__init__.py +13 -0
- tests/core/__init__.py +13 -0
- tests/core/test_disagg_utils.py +14 -0
- tests/core/test_dp_scheduler.py +650 -768
- tests/core/test_init.py +14 -0
- tests/distributed/__init__.py +13 -0
- tests/distributed/test_distributed_utils.py +120 -0
- tests/distributed/test_tpu_connector.py +478 -0
- tests/e2e/__init__.py +13 -0
- tests/e2e/test_async_scheduler.py +211 -0
- tests/e2e/test_data_parallel.py +289 -0
- tests/e2e/test_hybrid_kvcache.py +219 -0
- tests/e2e/test_local_disagg.py +257 -0
- tests/e2e/test_model_loader.py +268 -0
- tests/e2e/test_multi_modal_inference.py +111 -0
- tests/e2e/test_pipeline_parallel.py +265 -0
- tests/e2e/test_runai_model_streamer_loader.py +104 -0
- tests/e2e/test_sampling_params.py +269 -0
- tests/e2e/test_speculative_decoding.py +311 -0
- tests/e2e/test_structured_decoding.py +46 -0
- tests/executors/__init__.py +13 -0
- tests/executors/test_ray_distributed_executor.py +199 -0
- tests/experimental/__init__.py +13 -0
- tests/experimental/test_llama3_jax_stashed.py +208 -0
- tests/kernels/__init__.py +13 -0
- tests/kernels/collectives/__init__.py +13 -0
- tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
- tests/kernels/fused_moe_v1_test.py +14 -0
- tests/kernels/gmm_test.py +205 -0
- tests/kernels/mla_v1_test.py +143 -41
- tests/kernels/quantized_matmul_kernel_test.py +2 -34
- tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
- tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
- tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
- tests/layers/__init__.py +13 -0
- tests/layers/common/__init__.py +13 -0
- tests/layers/common/test_attention_interface.py +156 -0
- tests/layers/common/test_quantization.py +149 -0
- tests/layers/jax/__init__.py +13 -0
- tests/layers/jax/attention/__init__.py +13 -0
- tests/layers/jax/attention/test_common_attention.py +103 -0
- tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
- tests/layers/jax/attention/test_llama4_attention.py +135 -0
- tests/layers/jax/moe/__init__.py +13 -0
- tests/layers/jax/moe/test_deepseek_moe.py +235 -0
- tests/layers/jax/sample/__init__.py +13 -0
- tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
- tests/layers/jax/sample/test_sampling.py +115 -0
- tests/layers/jax/sample/test_sampling_metadata.py +254 -0
- tests/layers/jax/test_layers.py +155 -0
- tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
- tests/layers/jax/test_rope.py +93 -0
- tests/layers/jax/test_sharding.py +159 -0
- tests/layers/jax/test_transformer_block.py +152 -0
- tests/layers/vllm/__init__.py +13 -0
- tests/layers/vllm/test_attention.py +363 -0
- tests/layers/vllm/test_awq.py +405 -0
- tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
- tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
- tests/layers/vllm/test_fp8.py +17 -0
- tests/layers/vllm/test_mxfp4.py +312 -0
- tests/layers/vllm/test_unquantized.py +651 -0
- tests/layers/vllm/utils.py +87 -0
- tests/lora/__init__.py +13 -0
- tests/lora/conftest.py +14 -0
- tests/lora/test_bgmv.py +14 -0
- tests/lora/test_layers.py +21 -3
- tests/lora/test_lora.py +15 -1
- tests/lora/test_lora_perf.py +67 -0
- tests/models/__init__.py +13 -0
- tests/models/common/__init__.py +13 -0
- tests/models/common/test_model_loader.py +455 -0
- tests/models/jax/__init__.py +13 -0
- tests/models/jax/test_deepseek_v3.py +401 -0
- tests/models/jax/test_llama3.py +184 -0
- tests/models/jax/test_llama4.py +298 -0
- tests/models/jax/test_llama_eagle3.py +197 -0
- tests/models/jax/test_llama_guard_4.py +242 -0
- tests/models/jax/test_qwen2.py +172 -0
- tests/models/jax/test_qwen2_5_vl.py +605 -0
- tests/models/jax/test_qwen3.py +169 -0
- tests/models/jax/test_weight_loading.py +180 -0
- tests/models/jax/utils/__init__.py +13 -0
- tests/models/jax/utils/test_multi_modal_utils.py +212 -0
- tests/platforms/__init__.py +13 -0
- tests/platforms/test_tpu_platform.py +54 -0
- tests/runner/__init__.py +13 -0
- tests/runner/test_block_table.py +395 -0
- tests/runner/test_input_batch.py +226 -0
- tests/runner/test_kv_cache.py +220 -0
- tests/runner/test_kv_cache_manager.py +498 -0
- tests/runner/test_multimodal_manager.py +429 -0
- tests/runner/test_persistent_batch_manager.py +84 -0
- tests/runner/test_speculative_decoding_manager.py +368 -0
- tests/runner/test_structured_decoding_manager.py +220 -0
- tests/runner/test_tpu_runner.py +261 -0
- tests/runner/test_tpu_runner_dp.py +1099 -0
- tests/runner/test_tpu_runner_mesh.py +200 -0
- tests/runner/test_utils.py +411 -0
- tests/spec_decode/__init__.py +13 -0
- tests/spec_decode/test_eagle3.py +311 -0
- tests/test_base.py +14 -0
- tests/test_envs.py +78 -1
- tests/test_tpu_info.py +14 -0
- tests/test_utils.py +1 -43
- tests/worker/__init__.py +13 -0
- tests/worker/tpu_worker_test.py +414 -0
- tpu_inference/__init__.py +14 -0
- tpu_inference/core/__init__.py +13 -0
- tpu_inference/core/sched/__init__.py +13 -0
- tpu_inference/core/sched/dp_scheduler.py +372 -56
- tpu_inference/distributed/__init__.py +13 -0
- tpu_inference/distributed/jax_parallel_state.py +14 -0
- tpu_inference/distributed/tpu_connector.py +14 -9
- tpu_inference/distributed/utils.py +56 -4
- tpu_inference/envs.py +38 -7
- tpu_inference/executors/__init__.py +13 -0
- tpu_inference/executors/ray_distributed_executor.py +17 -0
- tpu_inference/experimental/__init__.py +13 -0
- tpu_inference/experimental/llama3_jax_stashed.py +14 -0
- tpu_inference/kernels/__init__.py +13 -0
- tpu_inference/kernels/collectives/__init__.py +13 -0
- tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
- tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
- tpu_inference/kernels/flash_attention/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
- tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
- tpu_inference/kernels/megablox/__init__.py +13 -0
- tpu_inference/kernels/megablox/common.py +54 -0
- tpu_inference/kernels/megablox/gmm.py +646 -0
- tpu_inference/kernels/mla/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/__init__.py +13 -0
- tpu_inference/kernels/mla/v1/kernel.py +117 -145
- tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
- tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
- tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
- tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
- tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
- tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
- tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
- tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
- tpu_inference/layers/__init__.py +13 -0
- tpu_inference/layers/common/__init__.py +13 -0
- tpu_inference/layers/common/attention_interface.py +26 -19
- tpu_inference/layers/common/attention_metadata.py +14 -0
- tpu_inference/layers/common/quant_methods.py +15 -0
- tpu_inference/layers/common/quantization.py +270 -0
- tpu_inference/layers/common/sharding.py +28 -5
- tpu_inference/layers/jax/__init__.py +13 -0
- tpu_inference/layers/jax/attention/__init__.py +13 -0
- tpu_inference/layers/jax/attention/attention.py +19 -6
- tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
- tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
- tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
- tpu_inference/layers/jax/base.py +14 -0
- tpu_inference/layers/jax/constants.py +13 -0
- tpu_inference/layers/jax/layers.py +14 -0
- tpu_inference/layers/jax/misc.py +14 -0
- tpu_inference/layers/jax/moe/__init__.py +13 -0
- tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
- tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
- tpu_inference/layers/jax/moe/moe.py +43 -3
- tpu_inference/layers/jax/pp_utils.py +53 -0
- tpu_inference/layers/jax/rope.py +14 -0
- tpu_inference/layers/jax/rope_interface.py +14 -0
- tpu_inference/layers/jax/sample/__init__.py +13 -0
- tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
- tpu_inference/layers/jax/sample/sampling.py +15 -1
- tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
- tpu_inference/layers/jax/transformer_block.py +14 -0
- tpu_inference/layers/vllm/__init__.py +13 -0
- tpu_inference/layers/vllm/attention.py +4 -4
- tpu_inference/layers/vllm/fused_moe.py +210 -260
- tpu_inference/layers/vllm/linear_common.py +57 -22
- tpu_inference/layers/vllm/quantization/__init__.py +16 -0
- tpu_inference/layers/vllm/quantization/awq.py +15 -1
- tpu_inference/layers/vllm/quantization/common.py +33 -18
- tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
- tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
- tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
- tpu_inference/layers/vllm/quantization/fp8.py +118 -0
- tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
- tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
- tpu_inference/layers/vllm/sharding.py +21 -4
- tpu_inference/lora/__init__.py +13 -0
- tpu_inference/lora/torch_lora_ops.py +8 -13
- tpu_inference/models/__init__.py +13 -0
- tpu_inference/models/common/__init__.py +13 -0
- tpu_inference/models/common/model_loader.py +74 -35
- tpu_inference/models/jax/__init__.py +13 -0
- tpu_inference/models/jax/deepseek_v3.py +267 -157
- tpu_inference/models/jax/gpt_oss.py +26 -10
- tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
- tpu_inference/models/jax/llama3.py +99 -36
- tpu_inference/models/jax/llama4.py +14 -0
- tpu_inference/models/jax/llama_eagle3.py +14 -0
- tpu_inference/models/jax/llama_guard_4.py +15 -1
- tpu_inference/models/jax/qwen2.py +17 -2
- tpu_inference/models/jax/qwen2_5_vl.py +18 -4
- tpu_inference/models/jax/qwen3.py +17 -2
- tpu_inference/models/jax/utils/__init__.py +13 -0
- tpu_inference/models/jax/utils/file_utils.py +14 -0
- tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
- tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
- tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +89 -26
- tpu_inference/models/jax/utils/weight_utils.py +39 -2
- tpu_inference/models/vllm/__init__.py +13 -0
- tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
- tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
- tpu_inference/platforms/__init__.py +14 -0
- tpu_inference/platforms/tpu_platform.py +47 -64
- tpu_inference/runner/__init__.py +13 -0
- tpu_inference/runner/compilation_manager.py +72 -37
- tpu_inference/runner/kv_cache.py +54 -20
- tpu_inference/runner/kv_cache_manager.py +46 -17
- tpu_inference/runner/lora_utils.py +14 -0
- tpu_inference/runner/multimodal_manager.py +15 -1
- tpu_inference/runner/persistent_batch_manager.py +14 -0
- tpu_inference/runner/speculative_decoding_manager.py +14 -0
- tpu_inference/runner/structured_decoding_manager.py +14 -0
- tpu_inference/runner/tpu_runner.py +44 -17
- tpu_inference/spec_decode/__init__.py +13 -0
- tpu_inference/spec_decode/jax/__init__.py +13 -0
- tpu_inference/spec_decode/jax/eagle3.py +13 -0
- tpu_inference/tpu_info.py +14 -0
- tpu_inference/utils.py +42 -36
- tpu_inference/worker/__init__.py +13 -0
- tpu_inference/worker/tpu_worker.py +63 -50
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/METADATA +7 -9
- tpu_inference-0.13.2rc3.dist-info/RECORD +261 -0
- tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
- tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
- tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
- tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
- tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/WHEEL +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/licenses/LICENSE +0 -0
- {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1624 @@
|
|
|
1
|
+
# Copyright 2025 Google LLC
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
"""
|
|
15
|
+
Tests for the JAX-based rejection sampler for speculative decoding on TPU.
|
|
16
|
+
This test suite is structured to mirror the GPU rejection sampler tests.
|
|
17
|
+
"""
|
|
18
|
+
from dataclasses import dataclass
|
|
19
|
+
from typing import List, Tuple
|
|
20
|
+
|
|
21
|
+
import jax
|
|
22
|
+
import jax.numpy as jnp
|
|
23
|
+
import numpy as np
|
|
24
|
+
import pytest
|
|
25
|
+
|
|
26
|
+
from tpu_inference.layers.jax.sample.rejection_sampler import (
|
|
27
|
+
PLACEHOLDER_TOKEN_ID, RejectionSampler)
|
|
28
|
+
from tpu_inference.layers.jax.sample.sampling_metadata import \
|
|
29
|
+
TPUSupportedSamplingMetadata
|
|
30
|
+
|
|
31
|
+
# ======================== CONSTANTS ========================
|
|
32
|
+
|
|
33
|
+
# Sentinel written into the unused tail of a padded draft_token_ids array.
PAD_TOKEN_ID = -999  # Padding token for draft_token_ids
# Width of the logits rows built by the helpers when no size is passed.
VOCAB_SIZE = 128  # Default vocabulary size for tests
# Multiplier applied to the real token count when sizing a padded array.
DEFAULT_PADDING_FACTOR = 1.5  # Default padding factor for padded tests
|
|
36
|
+
|
|
37
|
+
# ======================== DATA STRUCTURES ========================
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
class RejectionSamplerTestCase:
    """A single rejection-sampler scenario: inputs plus the expected result.

    Token lists are flattened across sequences; ``num_draft_per_seq`` says
    how many consecutive entries belong to each sequence.
    """
    # Identifier used for pytest ids and in assertion messages.
    name: str
    # Draft token ids for all sequences, concatenated.
    draft_tokens: List[int]
    # Token ids the generated target logits should produce on argmax.
    target_tokens: List[int]
    # Number of draft tokens belonging to each sequence.
    num_draft_per_seq: List[int]
    # Bonus token id for each sequence.
    bonus_tokens: List[int]
    # Expected parsed output, one list of token ids per sequence.
    expected: List[List[int]]
    # Human-readable summary of the scenario.
    description: str = ""
    # Whether the draft tokens should go through the padding path.
    use_padding: bool = False
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
# ======================== TEST DATA FACTORY ========================
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class TestDataFactory:
    """Factory for the ``RejectionSamplerTestCase`` objects used in this file.

    Every method is a static or class method, so the factory can be used
    either through the class itself or through an instance.
    """

    @staticmethod
    def create_test_case(
            name: str,
            draft_tokens: List[int],
            target_tokens: List[int],
            num_draft_per_seq: List[int],
            bonus_tokens: List[int],
            expected: List[List[int]],
            description: str = "",
            use_padding: bool = False) -> RejectionSamplerTestCase:
        """Create a single test case.

        Args:
            name: Identifier of the scenario.
            draft_tokens: Draft token ids, flattened across sequences.
            target_tokens: Token ids the target logits should argmax to.
            num_draft_per_seq: Number of draft tokens per sequence.
            bonus_tokens: Bonus token id per sequence.
            expected: Expected parsed output, one list per sequence.
            description: Free-text description. When empty, a title-cased
                version of ``name`` is used instead.
            use_padding: Whether the case exercises the padded-input path.

        Returns:
            The populated test case.
        """
        return RejectionSamplerTestCase(name=name,
                                        draft_tokens=draft_tokens,
                                        target_tokens=target_tokens,
                                        num_draft_per_seq=num_draft_per_seq,
                                        bonus_tokens=bonus_tokens,
                                        expected=expected,
                                        description=description
                                        or name.replace("_", " ").title(),
                                        use_padding=use_padding)

    @classmethod
    def create_with_padding_variant(
            cls,
            name: str,
            draft_tokens: List[int],
            target_tokens: List[int],
            num_draft_per_seq: List[int],
            bonus_tokens: List[int],
            expected: List[List[int]],
            description: str = "") -> List[RejectionSamplerTestCase]:
        """Create both normal and padded versions of a test case.

        The padded variant is named ``"<name>_padded"`` and is only produced
        when ``draft_tokens`` is non-empty. Both variants share the same
        input list objects.

        Returns:
            A list holding the normal case, followed by the padded case when
            one was created.
        """
        # Resolve the fallback description up front. Otherwise an empty
        # description would turn into the truthy string " (with padding)"
        # for the padded variant and bypass the name-derived default.
        base_description = description or name.replace("_", " ").title()

        test_cases = [
            cls.create_test_case(name=name,
                                 draft_tokens=draft_tokens,
                                 target_tokens=target_tokens,
                                 num_draft_per_seq=num_draft_per_seq,
                                 bonus_tokens=bonus_tokens,
                                 expected=expected,
                                 description=base_description)
        ]

        # A padded variant only makes sense when there is something to pad.
        if draft_tokens:
            test_cases.append(
                cls.create_test_case(
                    name=f"{name}_padded",
                    draft_tokens=draft_tokens,
                    target_tokens=target_tokens,
                    num_draft_per_seq=num_draft_per_seq,
                    bonus_tokens=bonus_tokens,
                    expected=expected,
                    description=f"{base_description} (with padding)",
                    use_padding=True))

        return test_cases

    @classmethod
    def get_basic_test_cases(cls) -> List[RejectionSamplerTestCase]:
        """Generate basic functionality test cases."""
        return [
            # Perfect match
            *cls.create_with_padding_variant(
                name="perfect_match",
                draft_tokens=[1, 2, 3],
                target_tokens=[1, 2, 3],
                num_draft_per_seq=[3],
                bonus_tokens=[4],
                expected=[[1, 2, 3, 4]],
                description="Draft tokens perfectly match target argmax"),
            # Early mismatch
            *cls.create_with_padding_variant(
                name="early_mismatch",
                draft_tokens=[1, 2, 3],
                target_tokens=[1, 5, 3],
                num_draft_per_seq=[3],
                bonus_tokens=[4],
                expected=[[1, 5]],
                description="Mismatch at position 1"),
            # Multiple sequences
            *cls.create_with_padding_variant(
                name="multiple_sequences",
                draft_tokens=[1, 2, 3, 4],
                target_tokens=[1, 2, 3, 7],
                num_draft_per_seq=[2, 2],
                bonus_tokens=[5, 6],
                expected=[[1, 2, 5], [3, 7]],
                description="Multiple sequences with mixed results"),
            # Single token sequence
            *cls.create_with_padding_variant(
                name="single_token_sequence",
                draft_tokens=[1],
                target_tokens=[1],
                num_draft_per_seq=[1],
                bonus_tokens=[2],
                expected=[[1, 2]],
                description="Single token sequence with perfect match"),
            # Empty sequence (no padding variant)
            cls.create_test_case(
                name="empty_sequence",
                draft_tokens=[],
                target_tokens=[],
                num_draft_per_seq=[0],
                bonus_tokens=[5],
                expected=[[5]],
                description="Empty sequence gets bonus token"),
        ]

    @classmethod
    def get_variable_length_test_cases(cls) -> List[RejectionSamplerTestCase]:
        """Generate variable length test cases."""
        return [
            # Variable length sequences
            *cls.create_with_padding_variant(
                name="variable_length_sequences",
                draft_tokens=[1, 2, 3],
                target_tokens=[1, 5, 3],
                num_draft_per_seq=[2, 1],
                bonus_tokens=[6, 7],
                expected=[[1, 5], [3, 7]],
                description="Sequences with different lengths"),
            # All different lengths
            *cls.create_with_padding_variant(
                name="all_different_lengths",
                draft_tokens=[1, 2, 3, 4, 5, 6],
                target_tokens=[1, 2, 3, 4, 5, 6],
                num_draft_per_seq=[1, 2, 3],
                bonus_tokens=[7, 9, 10],
                expected=[[1, 7], [2, 3, 9], [4, 5, 6, 10]],
                description="All sequences have different lengths"),
            # Mixed sequence lengths
            *cls.create_with_padding_variant(
                name="mixed_sequence_lengths",
                draft_tokens=[1, 2, 3, 4, 5],
                target_tokens=[1, 2, 3, 7, 5],
                num_draft_per_seq=[2, 3],
                bonus_tokens=[6, 8],
                expected=[[1, 2, 6], [3, 7]],
                description="Mixed lengths with different outcomes"),
        ]

    @classmethod
    def get_edge_case_test_cases(cls) -> List[RejectionSamplerTestCase]:
        """Generate edge case test cases."""
        return [
            # Zero length mixed
            *cls.create_with_padding_variant(
                name="zero_length_mixed",
                draft_tokens=[1, 2],
                target_tokens=[1, 2],
                num_draft_per_seq=[0, 2],
                bonus_tokens=[5, 6],
                expected=[[5], [1, 2, 6]],
                description="Zero-length sequence mixed with normal"),
            # All zero length (no padding variant)
            cls.create_test_case(name="all_zero_length",
                                 draft_tokens=[],
                                 target_tokens=[],
                                 num_draft_per_seq=[0, 0],
                                 bonus_tokens=[5, 6],
                                 expected=[[5], [6]],
                                 description="All sequences are zero-length"),
            # Immediate rejection
            *cls.create_with_padding_variant(
                name="immediate_rejection",
                draft_tokens=[1, 2, 3, 4, 5, 6],
                target_tokens=[9, 2, 3, 4, 5, 6],
                num_draft_per_seq=[3, 2, 1],
                bonus_tokens=[10, 11, 12],
                expected=[[9], [4, 5, 11], [6, 12]],
                description="Mixed immediate rejection and perfect matches"),
            # First token mismatch
            *cls.create_with_padding_variant(
                name="first_token_mismatch",
                draft_tokens=[1],
                target_tokens=[2],
                num_draft_per_seq=[1],
                bonus_tokens=[3],
                expected=[[2]],
                description="Single token mismatch"),
        ]

    @classmethod
    def get_all_test_cases(cls) -> List[RejectionSamplerTestCase]:
        """Get all test cases including basic, variable length, and edge cases."""
        return [
            *cls.get_basic_test_cases(),
            *cls.get_variable_length_test_cases(),
            *cls.get_edge_case_test_cases(),
        ]
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
# ======================== TEST HELPERS ========================
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
class RejectionSamplerTestHelper:
    """Static helpers that build inputs for, and run, rejection sampler tests."""

    @staticmethod
    def create_target_logits_from_tokens(
            target_token_ids: List[int],
            vocab_size: int = VOCAB_SIZE) -> jnp.ndarray:
        """
        Create target logits that will produce desired token ids on argmax.

        Args:
            target_token_ids: List of target token IDs
            vocab_size: Size of the vocabulary

        Returns:
            float32 JAX array of shape (len(target_token_ids), vocab_size)
            holding -100.0 everywhere except 100.0 at each row's target id.
        """
        num_tokens = len(target_token_ids)
        if num_tokens == 0:
            return jnp.empty((0, vocab_size), dtype=jnp.float32)

        # Start from uniformly low logits.
        target_logits = jnp.full((num_tokens, vocab_size),
                                 -100.0,
                                 dtype=jnp.float32)

        # Raise the desired token of every row in one scatter instead of
        # issuing one functional update per token.
        row_ids = jnp.arange(num_tokens)
        col_ids = jnp.asarray(target_token_ids, dtype=jnp.int32)
        return target_logits.at[row_ids, col_ids].set(100.0)

    @staticmethod
    def create_sampling_metadata(
        all_greedy: bool = True,
        batch_size: int = 1,
        top_k: int = -1,
        top_p: float = 1.0,
        temperature: float = 1.0,
    ) -> TPUSupportedSamplingMetadata:
        """
        Create TPU sampling metadata object.

        Args:
            all_greedy: When True the metadata is built with
                ``do_sampling=False``.
            batch_size: Length of the per-request top_k / top_p /
                temperature arrays.
            top_k: Value broadcast into the int32 top_k array.
            top_p: Value broadcast into the float32 top_p array.
            temperature: Value broadcast into the float32 temperature array.

        Returns:
            The metadata object, with ``logprobs`` set to False.
        """
        return TPUSupportedSamplingMetadata(
            do_sampling=not all_greedy,
            logprobs=False,
            top_k=jnp.full((batch_size, ), top_k, dtype=jnp.int32),
            top_p=jnp.full((batch_size, ), top_p, dtype=jnp.float32),
            temperature=jnp.full((batch_size, ),
                                 temperature,
                                 dtype=jnp.float32),
        )

    @staticmethod
    def create_padded_draft_tokens(
            draft_tokens: List[int],
            padding_factor: float = DEFAULT_PADDING_FACTOR) -> jnp.ndarray:
        """
        Create padded draft tokens array.

        Args:
            draft_tokens: List of draft tokens
            padding_factor: Factor to determine padding length

        Returns:
            int32 JAX array with the draft tokens first and PAD_TOKEN_ID in
            every remaining slot; an empty array when there are no tokens.
        """
        if not draft_tokens:
            return jnp.array([], dtype=jnp.int32)

        # Always add at least two padding slots; use the scaled length when
        # that is larger.
        actual_length = len(draft_tokens)
        padded_length = max(actual_length + 2,
                            int(actual_length * padding_factor))

        # Real tokens first, padding sentinel in the tail.
        padded_tokens = list(draft_tokens) + [PAD_TOKEN_ID] * (
            padded_length - actual_length)

        return jnp.array(padded_tokens, dtype=jnp.int32)

    @staticmethod
    def prepare_test_inputs(
        test_case: RejectionSamplerTestCase,
        vocab_size: int = VOCAB_SIZE
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """
        Prepare inputs for rejection sampler test.

        Args:
            test_case: Test case with input data
            vocab_size: Vocabulary size

        Returns:
            Tuple of (draft_token_ids, target_logits, num_draft_tokens,
            bonus_token_ids)
        """
        helper = RejectionSamplerTestHelper

        # These two do not depend on whether padding is exercised.
        num_draft_tokens = jnp.array(test_case.num_draft_per_seq,
                                     dtype=jnp.int32)
        target_logits = helper.create_target_logits_from_tokens(
            test_case.target_tokens, vocab_size)

        if test_case.use_padding and test_case.draft_tokens:
            # Build the padded array, then keep only the leading real
            # tokens: the sampler is handed the un-padded prefix.
            padded_draft_tokens = helper.create_padded_draft_tokens(
                test_case.draft_tokens)
            total_actual_tokens = sum(test_case.num_draft_per_seq)
            draft_token_ids = padded_draft_tokens[:total_actual_tokens]
        else:
            draft_token_ids = jnp.array(test_case.draft_tokens,
                                        dtype=jnp.int32)

        bonus_token_ids = jnp.array(test_case.bonus_tokens, dtype=jnp.int32)

        return (draft_token_ids, target_logits, num_draft_tokens,
                bonus_token_ids)

    @staticmethod
    def run_rejection_sampler_test(
        rejection_sampler: RejectionSampler,
        test_case: RejectionSamplerTestCase,
        vocab_size: int = VOCAB_SIZE,
    ) -> None:
        """
        Run a rejection sampler test from test case data.

        Args:
            rejection_sampler: RejectionSampler instance
            test_case: Test case to run
            vocab_size: Vocabulary size

        Raises:
            AssertionError: If the parsed sampler output differs from
                ``test_case.expected``.
        """
        helper = RejectionSamplerTestHelper
        # NOTE(review): the metadata uses the default batch_size=1 even for
        # multi-sequence cases; presumably unused on the greedy path -
        # confirm against RejectionSampler.
        metadata = helper.create_sampling_metadata(all_greedy=True)

        # Prepare inputs
        (draft_token_ids, target_logits, num_draft_tokens,
         bonus_token_ids) = helper.prepare_test_inputs(test_case, vocab_size)

        # Call the rejection sampler
        output = rejection_sampler(
            draft_token_ids=draft_token_ids,
            num_draft_tokens=num_draft_tokens,
            draft_probs=None,
            target_logits=target_logits,
            bonus_token_ids=bonus_token_ids,
            sampling_metadata=metadata,
        )

        # Convert once to a host array so the length and total are computed
        # without iterating a device array element by element.
        num_draft_tokens_cpu = np.asarray(num_draft_tokens)

        # Parse the output
        parsed_output = rejection_sampler.parse_output(
            output,
            vocab_size=vocab_size,
            num_draft_tokens_cpu=num_draft_tokens_cpu,
            batch_size=len(num_draft_tokens_cpu),
            padded_tokens_length=int(num_draft_tokens_cpu.sum()))

        assert parsed_output == test_case.expected, \
            f"Test '{test_case.name}': Expected {test_case.expected}, got {parsed_output}"
|
|
454
|
+
|
|
455
|
+
|
|
456
|
+
# ======================== FIXTURES ========================
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
@pytest.fixture
def rejection_sampler():
    """Provide a freshly constructed RejectionSampler to each test."""
    sampler = RejectionSampler()
    return sampler
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
@pytest.fixture
def test_helper():
    """Provide a RejectionSamplerTestHelper instance to each test."""
    helper = RejectionSamplerTestHelper()
    return helper
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
@pytest.fixture
def test_factory():
    """Provide a TestDataFactory instance to each test."""
    factory = TestDataFactory()
    return factory
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
# ======================== TEST CLASSES ========================
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
class TestRejectionSampler:
    """Greedy-path tests for the rejection sampler and its ``parse_output``.

    The tests fall into four groups: factory-generated scenarios,
    ``parse_output`` applied to hand-built flattened arrays, padded inputs,
    and multi-sequence accept/reject patterns.
    """

    # Underscore-prefixed helpers are not collected by pytest.

    @staticmethod
    def _run_cases(sampler, cases):
        """Feed each case in ``cases`` to the shared test-helper runner."""
        for case in cases:
            RejectionSamplerTestHelper.run_rejection_sampler_test(
                sampler, case)

    @staticmethod
    def _parse(sampler, token_ids, num_draft_tokens):
        """Call ``parse_output`` with sizes derived from ``num_draft_tokens``.

        ``batch_size`` is the number of entries in ``num_draft_tokens`` and
        ``padded_tokens_length`` is their sum.
        """
        return sampler.parse_output(
            token_ids,
            VOCAB_SIZE,
            num_draft_tokens_cpu=np.asarray(num_draft_tokens),
            batch_size=len(num_draft_tokens),
            padded_tokens_length=int(sum(num_draft_tokens)),
        )

    # =============== Basic Functionality Tests ===============

    @pytest.mark.parametrize(
        "test_case",
        TestDataFactory.get_all_test_cases(),
        ids=lambda case: case.name,
    )
    def test_rejection_sampler_scenarios(self, rejection_sampler, test_case):
        """Run one factory-provided scenario (test id is the case name)."""
        RejectionSamplerTestHelper.run_rejection_sampler_test(
            rejection_sampler, test_case)

    def test_multiple_mismatches(self, rejection_sampler, test_factory):
        """Two sequences, each containing a draft/target mismatch."""
        variants = test_factory.create_with_padding_variant(
            name="multiple_mismatches",
            draft_tokens=[1, 2, 3, 4, 5, 6],
            target_tokens=[1, 2, 7, 4, 8, 6],
            num_draft_per_seq=[3, 3],
            bonus_tokens=[8, 9],
            expected=[[1, 2, 7], [4, 8]],
            description="Both sequences have mismatches",
        )
        self._run_cases(rejection_sampler, variants)

    # =============== Parse Output Tests ===============

    def test_parse_output_basic(self, rejection_sampler):
        """``parse_output`` regroups a flattened array into per-sequence lists."""
        # Flattened layout: [main_tokens..., bonus_tokens...]
        flat_ids = jnp.concatenate([
            jnp.array([10, 20, 30, 50, 60], dtype=jnp.int32),
            jnp.array([40, 70], dtype=jnp.int32),
        ])
        counts = jnp.array([3, 2], dtype=jnp.int32)

        parsed_output = self._parse(rejection_sampler, flat_ids, counts)

        expected = [[10, 20, 30, 40], [50, 60, 70]]
        assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"

    def test_parse_output_with_placeholders(self, rejection_sampler):
        """Placeholder entries (rejected slots) are absent from the result."""
        flat_ids = jnp.concatenate([
            jnp.array(
                [10, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID, 20, 30],
                dtype=jnp.int32),
            jnp.array([PLACEHOLDER_TOKEN_ID, 40], dtype=jnp.int32),
        ])
        counts = jnp.array([3, 2], dtype=jnp.int32)

        parsed_output = self._parse(rejection_sampler, flat_ids, counts)

        expected = [[10], [20, 30, 40]]
        assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"

    def test_parse_output_invalid_tokens(self, rejection_sampler):
        """Token ids above the vocabulary size are absent from the result."""
        flat_ids = jnp.concatenate([
            jnp.array([10, VOCAB_SIZE + 1, 20], dtype=jnp.int32),
            jnp.array([VOCAB_SIZE + 2], dtype=jnp.int32),
        ])
        counts = jnp.array([3], dtype=jnp.int32)

        parsed_output = self._parse(rejection_sampler, flat_ids, counts)

        expected = [[10, 20]]  # Invalid tokens filtered out
        assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"

    def test_parse_output_empty_sequences(self, rejection_sampler):
        """With zero draft tokens per sequence only the bonus tokens remain."""
        flat_ids = jnp.concatenate([
            jnp.array([], dtype=jnp.int32),
            jnp.array([50, 60], dtype=jnp.int32),
        ])
        counts = jnp.array([0, 0], dtype=jnp.int32)

        parsed_output = self._parse(rejection_sampler, flat_ids, counts)

        expected = [[50], [60]]  # Only bonus tokens
        assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"

    # =============== Padding-Specific Tests ===============

    def test_padding_ignored_correctly(self, rejection_sampler, test_factory):
        """Every variant from the factory must yield the same expected output."""
        variants = test_factory.create_with_padding_variant(
            name="padding_test",
            draft_tokens=[1, 2],
            target_tokens=[1, 5],
            num_draft_per_seq=[2],
            bonus_tokens=[3],
            expected=[[1, 5]],
            description="Test padding is ignored",
        )
        self._run_cases(rejection_sampler, variants)

    def test_extreme_padding(self, rejection_sampler, test_helper):
        """Two real draft tokens taken from a buffer with 20 padding entries.

        Note that the padded buffer is sliced down to the real tokens before
        the sampler is called, so the sampler itself only receives two ids.
        """
        metadata = test_helper.create_sampling_metadata(all_greedy=True)

        # [1, 2] followed by 20 padding tokens.
        padded_buffer = jnp.array([1, 2] + [PAD_TOKEN_ID] * 20,
                                  dtype=jnp.int32)

        counts = jnp.array([2], dtype=jnp.int32)
        real_token_count = int(jnp.sum(counts))
        real_ids = padded_buffer[:real_token_count]

        logits = test_helper.create_target_logits_from_tokens([1, 5],
                                                              VOCAB_SIZE)
        bonus = jnp.array([3], dtype=jnp.int32)

        output = rejection_sampler(
            draft_token_ids=real_ids,
            num_draft_tokens=counts,
            draft_probs=None,
            target_logits=logits,
            bonus_token_ids=bonus,
            sampling_metadata=metadata,
        )
        parsed_output = self._parse(rejection_sampler, output, counts)

        expected = [[1, 5]]  # Should ignore all padding
        assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"

    def test_realistic_flattened_with_padding(self, rejection_sampler,
                                              test_factory):
        """Two sequences of different lengths built with ``use_padding=True``."""
        case = test_factory.create_test_case(
            name="realistic_flattened_with_padding",
            draft_tokens=[1, 2, 3],
            target_tokens=[1, 5, 3],
            num_draft_per_seq=[2, 1],
            bonus_tokens=[6, 7],
            expected=[[1, 5], [3, 7]],
            description="Realistic flattened input with padding",
            use_padding=True,
        )
        RejectionSamplerTestHelper.run_rejection_sampler_test(
            rejection_sampler, case)

    # =============== Segment Operation Edge Case Tests ===============

    def test_all_sequences_immediate_mismatch(self, rejection_sampler,
                                              test_factory):
        """Every sequence's first draft token differs from the target."""
        variants = test_factory.create_with_padding_variant(
            name="all_immediate_mismatch",
            draft_tokens=[1, 2, 3, 4, 5, 6, 7, 8, 9],
            # The first token of each 3-token sequence mismatches.
            target_tokens=[10, 2, 3, 11, 5, 6, 12, 8, 9],
            num_draft_per_seq=[3, 3, 3],
            bonus_tokens=[20, 21, 22],
            # Only the correction token per sequence; no bonus token.
            expected=[[10], [11], [12]],
            description="All sequences have immediate first token mismatch",
        )
        self._run_cases(rejection_sampler, variants)

    def test_all_sequences_perfect_match(self, rejection_sampler,
                                         test_factory):
        """Every draft token equals the target; each sequence gets its bonus."""
        variants = test_factory.create_with_padding_variant(
            name="all_perfect_match",
            draft_tokens=[1, 2, 3, 4, 5, 6, 7, 8, 9],
            target_tokens=[1, 2, 3, 4, 5, 6, 7, 8, 9],
            num_draft_per_seq=[3, 3, 3],
            bonus_tokens=[10, 11, 12],
            # All drafts accepted, followed by the bonus token.
            expected=[[1, 2, 3, 10], [4, 5, 6, 11], [7, 8, 9, 12]],
            description="All sequences have perfect token matches",
        )
        self._run_cases(rejection_sampler, variants)

    def test_extreme_length_imbalance(self, rejection_sampler, test_factory):
        """One 15-token sequence alongside a 1-token and a 2-token sequence."""
        case = test_factory.create_test_case(
            name="extreme_length_imbalance",
            draft_tokens=list(range(1, 19)),
            # Identical to the drafts except index 16 (17 -> 20).
            target_tokens=list(range(1, 17)) + [20, 18],
            num_draft_per_seq=[15, 1, 2],
            bonus_tokens=[100, 101, 102],
            expected=[
                list(range(1, 16)) + [100],  # all 15 accepted + bonus
                [16, 101],  # single token accepted + bonus
                [20],  # first token mismatch, no bonus
            ],
            description="Extreme length imbalance between sequences",
        )
        RejectionSamplerTestHelper.run_rejection_sampler_test(
            rejection_sampler, case)

    def test_mixed_accept_reject_patterns(self, rejection_sampler,
                                          test_factory):
        """Perfect match, immediate rejection, perfect match — in that order."""
        variants = test_factory.create_with_padding_variant(
            name="mixed_accept_reject",
            draft_tokens=[1, 2, 3, 4, 5, 6, 7, 8, 9],
            target_tokens=[1, 2, 3, 10, 5, 6, 7, 8, 9],
            num_draft_per_seq=[3, 3, 3],
            bonus_tokens=[20, 21, 22],
            expected=[[1, 2, 3, 20], [10], [7, 8, 9, 22]],
            description="Mix of perfect matches and immediate rejections",
        )
        self._run_cases(rejection_sampler, variants)

    def test_mismatches_at_same_position(self, rejection_sampler,
                                         test_factory):
        """Each sequence mismatches at its middle (index 1) token."""
        variants = test_factory.create_with_padding_variant(
            name="same_position_mismatch",
            draft_tokens=[1, 2, 3, 4, 5, 6, 7, 8, 9],
            target_tokens=[1, 10, 3, 4, 11, 6, 7, 12, 9],
            num_draft_per_seq=[3, 3, 3],
            bonus_tokens=[20, 21, 22],
            expected=[[1, 10], [4, 11], [7, 12]],
            description="Mismatches at same position in all sequences",
        )
        self._run_cases(rejection_sampler, variants)

    def test_single_long_sequence(self, rejection_sampler, test_helper):
        """One 30-token sequence whose target diverges at index 27."""
        metadata = test_helper.create_sampling_metadata(all_greedy=True)

        matching_prefix = list(range(1, 28))
        draft = list(range(1, 31))
        target = matching_prefix + [99, 29, 30]  # index 27 differs (28 -> 99)

        draft_ids = jnp.array(draft, dtype=jnp.int32)
        logits = test_helper.create_target_logits_from_tokens(
            target, VOCAB_SIZE)
        counts = jnp.array([30], dtype=jnp.int32)
        bonus = jnp.array([100], dtype=jnp.int32)

        output = rejection_sampler(
            draft_token_ids=draft_ids,
            num_draft_tokens=counts,
            draft_probs=None,
            target_logits=logits,
            bonus_token_ids=bonus,
            sampling_metadata=metadata,
        )
        parsed_output = self._parse(rejection_sampler, output, counts)

        # Accepted prefix followed by the target token at the mismatch.
        expected = [matching_prefix + [99]]
        assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
|
|
795
|
+
|
|
796
|
+
|
|
797
|
+
# ======================== NON-GREEDY SAMPLING TESTS ========================
|
|
798
|
+
|
|
799
|
+
|
|
800
|
+
class TestNonGreedyRejectionSampler:
    """Tests that call the rejection sampler with ``all_greedy=False``."""

    # Underscore-prefixed helpers are not collected by pytest.

    @staticmethod
    def _peaked_draft_probs(draft_tokens):
        """Return one probability row per draft token, peaked on that token.

        Each row starts as logits of -100.0, gets 100.0 at the drafted token
        id, and is then passed through a softmax over the last axis.
        """
        logits = jnp.full((len(draft_tokens), VOCAB_SIZE),
                          -100.0,
                          dtype=jnp.float32)
        for row, token_id in enumerate(draft_tokens):
            logits = logits.at[row, token_id].set(100.0)
        return jax.nn.softmax(logits, axis=-1)

    @staticmethod
    def _sample_and_parse(sampler, metadata, *, draft_token_ids,
                          target_logits, draft_probs, num_draft_tokens,
                          bonus_token_ids, key):
        """Call the sampler with ``key`` and return the parsed token lists.

        ``batch_size`` and ``padded_tokens_length`` passed to ``parse_output``
        are the length and the sum of ``num_draft_tokens`` respectively.
        """
        raw_output = sampler(
            draft_token_ids=draft_token_ids,
            num_draft_tokens=num_draft_tokens,
            draft_probs=draft_probs,
            target_logits=target_logits,
            bonus_token_ids=bonus_token_ids,
            sampling_metadata=metadata,
            key=key,
        )
        return sampler.parse_output(
            raw_output,
            VOCAB_SIZE,
            num_draft_tokens_cpu=np.asarray(num_draft_tokens),
            batch_size=len(num_draft_tokens),
            padded_tokens_length=int(sum(num_draft_tokens)),
        )

    def test_non_greedy_basic_functionality(self, rejection_sampler,
                                            test_helper):
        """Single sequence whose first draft token matches the target."""
        metadata = test_helper.create_sampling_metadata(all_greedy=False)

        draft_tokens = [10, 20, 30]
        target_tokens = [10, 50, 30]  # differs at index 1

        parsed_output = self._sample_and_parse(
            rejection_sampler,
            metadata,
            draft_token_ids=jnp.array(draft_tokens, dtype=jnp.int32),
            target_logits=test_helper.create_target_logits_from_tokens(
                target_tokens, VOCAB_SIZE),
            draft_probs=self._peaked_draft_probs(draft_tokens),
            num_draft_tokens=jnp.array([3], dtype=jnp.int32),
            bonus_token_ids=jnp.array([99], dtype=jnp.int32),
            key=jax.random.PRNGKey(42),
        )

        # The full output depends on sampling; only the leading token of the
        # single returned sequence is pinned.
        assert len(parsed_output) == 1
        assert len(parsed_output[0]) >= 1
        assert parsed_output[0][0] == 10  # First token should match

    def test_non_greedy_deterministic_with_seed(self, rejection_sampler,
                                                test_helper):
        """Five calls with the same PRNG key must give identical results."""
        metadata = test_helper.create_sampling_metadata(all_greedy=False)

        draft_tokens = [1, 2, 3, 4]
        target_tokens = [1, 5, 3, 6]  # differs at indices 1 and 3

        inputs = dict(
            draft_token_ids=jnp.array(draft_tokens, dtype=jnp.int32),
            target_logits=test_helper.create_target_logits_from_tokens(
                target_tokens, VOCAB_SIZE),
            draft_probs=self._peaked_draft_probs(draft_tokens),
            num_draft_tokens=jnp.array([4], dtype=jnp.int32),
            bonus_token_ids=jnp.array([99], dtype=jnp.int32),
            key=jax.random.PRNGKey(123),
        )

        outputs = [
            self._sample_and_parse(rejection_sampler, metadata, **inputs)
            for _ in range(5)
        ]

        # Every repeat is compared against the first run.
        for i in range(1, len(outputs)):
            assert outputs[i] == outputs[
                0], f"Run {i}: {outputs[i]} != {outputs[0]}"

    def test_non_greedy_with_draft_probs_none(self, rejection_sampler,
                                              test_helper):
        """Non-greedy call with ``draft_probs=None``."""
        metadata = test_helper.create_sampling_metadata(all_greedy=False)

        draft_tokens = [15, 25]
        target_tokens = [15, 35]  # differs at index 1

        parsed_output = self._sample_and_parse(
            rejection_sampler,
            metadata,
            draft_token_ids=jnp.array(draft_tokens, dtype=jnp.int32),
            target_logits=test_helper.create_target_logits_from_tokens(
                target_tokens, VOCAB_SIZE),
            draft_probs=None,  # No draft probabilities
            num_draft_tokens=jnp.array([2], dtype=jnp.int32),
            bonus_token_ids=jnp.array([88], dtype=jnp.int32),
            key=jax.random.PRNGKey(777),
        )

        assert len(parsed_output) == 1
        assert len(parsed_output[0]) >= 1
        assert parsed_output[0][0] == 15  # First token should match

    def test_non_greedy_multiple_sequences(self, rejection_sampler,
                                           test_helper):
        """Three sequences (2, 3 and 2 draft tokens) in one call."""
        metadata = test_helper.create_sampling_metadata(all_greedy=False)

        draft_tokens = [1, 2, 3, 4, 5, 6, 7]  # [1,2] [3,4,5] [6,7]
        target_tokens = [1, 5, 3, 8, 5, 6, 9]  # each differs at its index 1

        parsed_output = self._sample_and_parse(
            rejection_sampler,
            metadata,
            draft_token_ids=jnp.array(draft_tokens, dtype=jnp.int32),
            target_logits=test_helper.create_target_logits_from_tokens(
                target_tokens, VOCAB_SIZE),
            draft_probs=self._peaked_draft_probs(draft_tokens),
            num_draft_tokens=jnp.array([2, 3, 2], dtype=jnp.int32),
            bonus_token_ids=jnp.array([11, 12, 13], dtype=jnp.int32),
            key=jax.random.PRNGKey(456),
        )

        assert len(parsed_output) == 3

        # Only the leading token of each sequence is checked.
        assert parsed_output[0][0] == 1  # [1, 2] vs [1, 5]
        assert parsed_output[1][0] == 3  # [3, 4, 5] vs [3, 8, 5]
        assert parsed_output[2][0] == 6  # [6, 7] vs [6, 9]

    def test_non_greedy_with_all_accepted_tokens(self, rejection_sampler,
                                                 test_helper):
        """Draft and target tokens are identical.

        Only the number of returned sequences is asserted; the contents of
        the sequence are not checked.
        """
        metadata = test_helper.create_sampling_metadata(all_greedy=False)

        draft_tokens = [10, 20, 30]
        target_tokens = [10, 20, 30]  # same as the drafts

        parsed_output = self._sample_and_parse(
            rejection_sampler,
            metadata,
            draft_token_ids=jnp.array(draft_tokens, dtype=jnp.int32),
            target_logits=test_helper.create_target_logits_from_tokens(
                target_tokens, VOCAB_SIZE),
            draft_probs=self._peaked_draft_probs(draft_tokens),
            num_draft_tokens=jnp.array([3], dtype=jnp.int32),
            bonus_token_ids=jnp.array([99], dtype=jnp.int32),
            key=jax.random.PRNGKey(999),
        )

        assert len(parsed_output) == 1

    def test_non_greedy_empty_sequence(self, rejection_sampler, test_helper):
        """Two sequences with zero draft tokens yield just their bonus tokens."""
        metadata = test_helper.create_sampling_metadata(all_greedy=False)

        parsed_output = self._sample_and_parse(
            rejection_sampler,
            metadata,
            draft_token_ids=jnp.array([], dtype=jnp.int32),
            # Zero rows, VOCAB_SIZE columns.
            target_logits=jnp.array([], dtype=jnp.float32).reshape(
                0, VOCAB_SIZE),
            draft_probs=None,
            num_draft_tokens=jnp.array([0, 0], dtype=jnp.int32),
            bonus_token_ids=jnp.array([77, 88], dtype=jnp.int32),
            key=jax.random.PRNGKey(333),
        )

        expected = [[77], [88]]
        assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"

    def test_non_greedy_requires_key(self, rejection_sampler, test_helper):
        """Calling with ``all_greedy=False`` and ``key=None`` raises ValueError."""
        metadata = test_helper.create_sampling_metadata(all_greedy=False)

        draft_ids = jnp.array([1, 2], dtype=jnp.int32)
        logits = test_helper.create_target_logits_from_tokens([1, 3],
                                                              VOCAB_SIZE)
        counts = jnp.array([2], dtype=jnp.int32)
        bonus = jnp.array([99], dtype=jnp.int32)

        with pytest.raises(ValueError, match="A random key must be provided"):
            rejection_sampler(
                draft_token_ids=draft_ids,
                num_draft_tokens=counts,
                draft_probs=None,
                target_logits=logits,
                bonus_token_ids=bonus,
                sampling_metadata=metadata,
                key=None,  # No key provided
            )

    def test_non_greedy_vs_greedy_same_perfect_case(self, rejection_sampler,
                                                    test_helper):
        """Run identical inputs through the greedy and non-greedy paths.

        The greedy result is asserted exactly. The non-greedy result is only
        required to be a single, non-empty sequence; the two results are not
        compared with each other.
        """
        draft_tokens = [5, 15, 25]
        target_tokens = [5, 15, 25]  # same as the drafts

        counts = jnp.array([3], dtype=jnp.int32)
        draft_ids = jnp.array(draft_tokens, dtype=jnp.int32)
        logits = test_helper.create_target_logits_from_tokens(
            target_tokens, VOCAB_SIZE)
        shared_inputs = dict(
            draft_token_ids=draft_ids,
            num_draft_tokens=counts,
            draft_probs=self._peaked_draft_probs(draft_tokens),
            target_logits=logits,
            bonus_token_ids=jnp.array([99], dtype=jnp.int32),
        )

        # Greedy path: no PRNG key is passed.
        greedy_metadata = test_helper.create_sampling_metadata(all_greedy=True)
        greedy_output = rejection_sampler(sampling_metadata=greedy_metadata,
                                          **shared_inputs)

        # Non-greedy path: same inputs plus a PRNG key.
        non_greedy_metadata = test_helper.create_sampling_metadata(
            all_greedy=False)
        key = jax.random.PRNGKey(555)
        non_greedy_output = rejection_sampler(
            sampling_metadata=non_greedy_metadata, key=key, **shared_inputs)

        # parse_output arguments are passed positionally here.
        greedy_parsed = rejection_sampler.parse_output(
            greedy_output, VOCAB_SIZE, np.asarray(counts), 1, 3)
        non_greedy_parsed = rejection_sampler.parse_output(
            non_greedy_output, VOCAB_SIZE, np.asarray(counts), 1, 3)

        # Greedy: all three drafts followed by the bonus token.
        assert greedy_parsed == [[5, 15, 25, 99]]

        assert len(non_greedy_parsed) == 1
        assert len(non_greedy_parsed[0]) >= 1
|
|
1174
|
+
|
|
1175
|
+
|
|
1176
|
+
# ======================== STATISTICAL DISTRIBUTION VALIDATION ========================
|
|
1177
|
+
|
|
1178
|
+
|
|
1179
|
+
class TestStatisticalDistributionValidation:
    """Test suite for validating rejection sampling produces correct probability distributions."""

    def test_rejection_sampling_approximates_target_distribution(self):
        """Verify rejection sampling approximates target distribution.

        This test validates that rejection sampling produces the correct
        probability distribution despite sampling from a potentially distinct
        draft distribution.

        The test works by:
        1. Creating random target and draft probability distributions
        2. Using rejection sampling to generate token samples
        3. Estimating the output distribution from samples
        4. Comparing convergence to target vs random reference distributions

        We expect that as sample size increases, the distance to the target
        distribution decreases much more than the distance to random
        distributions.
        """
        vocab_size = 10
        k = 2
        num_reference_probs = 100

        # One key per distribution so draft, target and reference logits are
        # drawn independently of each other.
        key = jax.random.PRNGKey(42)
        draft_key, target_key, reference_key = jax.random.split(key, 3)

        # Draft and target distributions
        draft_logits = jax.random.normal(draft_key, (vocab_size, ))
        draft_probs = jax.nn.softmax(draft_logits)

        target_logits = jax.random.normal(target_key, (vocab_size, ))
        target_probs = jax.nn.softmax(target_logits)

        # Reference distributions for comparison
        reference_logits = jax.random.normal(
            reference_key, (num_reference_probs, vocab_size))
        reference_probs = jax.nn.softmax(reference_logits, axis=-1)

        sample_sizes = [10, 100, 1_000, 10_000, 100_000]
        distance_wrt_reference: List[float] = []
        distance_wrt_target: List[float] = []

        for num_samples in sample_sizes:
            # Estimate rejection sampling distribution
            estimated_probs = self._estimate_rejection_sampling_pdf(
                draft_probs, target_logits, k, vocab_size, num_samples)

            # Mean L2 distance to the unrelated reference distributions, and
            # L2 distance to the target the sampler is supposed to match.
            reference_vs_rejsample_dist = float(
                jnp.mean(
                    jnp.linalg.norm(reference_probs - estimated_probs[None, :],
                                    axis=-1)))
            target_vs_rejsample_dist = float(
                jnp.linalg.norm(target_probs - estimated_probs))

            distance_wrt_reference.append(reference_vs_rejsample_dist)
            distance_wrt_target.append(target_vs_rejsample_dist)

            print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
                  f"{reference_vs_rejsample_dist=:.05f}")

        # Calculate relative improvements
        relative_change_target = self._get_ratio_first_to_last(
            distance_wrt_target)
        relative_change_reference = self._get_ratio_first_to_last(
            distance_wrt_reference)

        print(f"Target improvement ratio: {relative_change_target:.02f}")
        print(f"Reference improvement ratio: {relative_change_reference:.02f}")

        # Validation: Target distribution should converge much better than reference
        expected_improvement_multiplier = 20
        assert (relative_change_target >
                relative_change_reference * expected_improvement_multiplier), \
            f"Target convergence ({relative_change_target:.2f}) should be " \
            f"{expected_improvement_multiplier}x better than reference " \
            f"({relative_change_reference:.2f})"

    def _estimate_rejection_sampling_pdf(
        self,
        draft_probs: jnp.ndarray,
        target_logits: jnp.ndarray,
        k: int,
        vocab_size: int,
        num_samples: int,
    ) -> jnp.ndarray:
        """Estimate probability distribution of rejection sampling output.

        Args:
            draft_probs: Draft probability distribution [vocab_size]
            target_logits: Target logits [vocab_size]
            k: Number of draft tokens per sequence
            vocab_size: Size of vocabulary
            num_samples: Number of samples to generate

        Returns:
            Estimated probability distribution [vocab_size]. A uniform
            distribution is returned if the sampler produced no tokens.
        """
        rejection_sampler = RejectionSampler()

        # Prepare inputs in the flattened format expected by TPU sampler
        num_tokens = num_samples * k

        # Every flattened token position shares the same draft/target rows.
        draft_probs_expanded = jnp.tile(draft_probs[None, :], (num_tokens, 1))
        target_logits_expanded = jnp.tile(target_logits[None, :],
                                          (num_tokens, 1))

        # Generate random draft token ids from draft distribution
        key = jax.random.PRNGKey(123)
        draft_tokens_2d = jax.random.categorical(key,
                                                 jnp.log(draft_probs + 1e-8),
                                                 shape=(num_samples, k))
        draft_token_ids = draft_tokens_2d.flatten()

        # Prepare other inputs
        num_draft_tokens = jnp.full((num_samples, ), k, dtype=jnp.int32)
        bonus_token_ids = jnp.zeros((num_samples, ),
                                    dtype=jnp.int32)  # Not used in estimation

        # Create sampling metadata for non-greedy sampling
        sampling_metadata = TPUSupportedSamplingMetadata(
            do_sampling=True,  # Non-greedy sampling
            logprobs=False,
            top_k=jnp.full((num_samples, ), -1, dtype=jnp.int32),
            top_p=jnp.full((num_samples, ), 1.0, dtype=jnp.float32),
            temperature=jnp.full((num_samples, ), 1.0, dtype=jnp.float32),
        )

        # Run rejection sampling
        sample_key = jax.random.PRNGKey(456)
        output_token_ids = rejection_sampler(
            draft_token_ids=draft_token_ids,
            num_draft_tokens=num_draft_tokens,
            draft_probs=draft_probs_expanded,
            target_logits=target_logits_expanded,
            bonus_token_ids=bonus_token_ids,
            sampling_metadata=sampling_metadata,
            key=sample_key,
        )

        parsed_output = rejection_sampler.parse_output(
            output_token_ids,
            vocab_size=vocab_size,
            num_draft_tokens_cpu=np.asarray(num_draft_tokens),
            batch_size=num_samples,
            padded_tokens_length=num_tokens)

        # Keep at most the first k tokens of every sequence: anything beyond
        # position k is the bonus token. Slicing also covers sequences that
        # are shorter than k (taken whole) or empty (contribute nothing).
        all_tokens = []
        for seq_tokens in parsed_output:
            all_tokens.extend(seq_tokens[:k])

        if not all_tokens:
            # Fallback if no tokens generated
            return jnp.ones(vocab_size) / vocab_size

        tokens_array = np.array(all_tokens, dtype=np.int32)

        # One unit-width bin per token id turns the histogram into an
        # empirical probability distribution.
        hist, _ = np.histogram(tokens_array,
                               bins=vocab_size,
                               range=(0, vocab_size),
                               density=True)

        # Normalize to ensure it sums to 1
        hist = hist / (hist.sum() + 1e-8)

        return jnp.array(hist, dtype=jnp.float32)

    def _get_ratio_first_to_last(self, elements: List[float]) -> float:
        """Calculate ratio of first to last element in list.

        Args:
            elements: Sequence of distances, ordered by increasing sample size.

        Returns:
            ``elements[0] / elements[-1]``. Lists with fewer than two entries
            yield 1.0 (nothing to compare). When the last element is zero the
            result is 1.0 if the first is zero as well (no change), otherwise
            a signed infinity: shrinking to exactly zero is perfect
            convergence, not "no improvement".
        """
        if len(elements) < 2:
            return 1.0

        first, last = elements[0], elements[-1]
        if last == 0:
            if first == 0:
                return 1.0
            return float("inf") if first > 0 else float("-inf")
        return first / last
|
|
1370
|
+
|
|
1371
|
+
|
|
1372
|
+
# ======================== TOP-K AND TOP-P SAMPLING TESTS ========================
|
|
1373
|
+
|
|
1374
|
+
|
|
1375
|
+
class TestTopKTopPSampling:
|
|
1376
|
+
"""Test suite for top-k and top-p sampling with rejection sampler."""
|
|
1377
|
+
|
|
1378
|
+
def _test_masked_logits(
    self,
    rejection_sampler: RejectionSampler,
    batch_size: int,
    num_draft_tokens: int,
    vocab_size: int,
    target_logits: jnp.ndarray,
    allowed_tokens_per_pos: List[jnp.ndarray],
    sampling_metadata: TPUSupportedSamplingMetadata,
):
    """Assert that the sampler only ever emits tokens from the allowed sets.

    Random draft probabilities and draft tokens are generated once, the
    sampler is run with ten different keys, and every non-bonus output
    token is checked against the allowed set of its flattened position.

    Args:
        rejection_sampler: The rejection sampler instance
        batch_size: Number of sequences in the batch
        num_draft_tokens: Number of draft tokens per sequence
        vocab_size: Size of vocabulary
        target_logits: Target logits tensor
        allowed_tokens_per_pos: List of allowed token arrays for each position
        sampling_metadata: Sampling metadata with top-k/top-p settings
    """
    num_tokens = batch_size * num_draft_tokens

    # Random draft distribution per flattened token position.
    draft_probs = jax.nn.softmax(
        jax.random.normal(jax.random.PRNGKey(42), (num_tokens, vocab_size)),
        axis=-1)

    # Draft tokens are drawn from that same draft distribution.
    draft_token_ids = jax.random.categorical(jax.random.PRNGKey(123),
                                             jnp.log(draft_probs + 1e-8),
                                             shape=(num_tokens, ))

    num_draft_per_seq = jnp.full((batch_size, ),
                                 num_draft_tokens,
                                 dtype=jnp.int32)
    bonus_token_ids = jnp.zeros((batch_size, ), dtype=jnp.int32)
    num_positions = len(allowed_tokens_per_pos)

    # Several independent runs to get statistical confidence.
    observations = []
    for run_key in jax.random.split(jax.random.PRNGKey(456), 10):
        raw_output = rejection_sampler(
            draft_token_ids=draft_token_ids,
            num_draft_tokens=num_draft_per_seq,
            draft_probs=draft_probs,
            target_logits=target_logits,
            bonus_token_ids=bonus_token_ids,
            sampling_metadata=sampling_metadata,
            key=run_key,
        )
        sequences = rejection_sampler.parse_output(
            raw_output,
            vocab_size=vocab_size,
            num_draft_tokens_cpu=np.asarray(num_draft_per_seq),
            batch_size=batch_size,
            padded_tokens_length=num_tokens)

        for seq_idx, seq_tokens in enumerate(sequences):
            first_flat_idx = seq_idx * num_draft_tokens
            # Slicing drops the bonus token, which has no allowed set.
            for offset, sampled in enumerate(seq_tokens[:num_draft_tokens]):
                flat_idx = first_flat_idx + offset
                if flat_idx >= num_positions:
                    continue
                observations.append(
                    (flat_idx, sampled, allowed_tokens_per_pos[flat_idx]))

    # Check that all sampled tokens are within allowed sets
    for flat_idx, sampled, permitted in observations:
        assert sampled in permitted, \
            f"Token {sampled} at position {flat_idx} not in allowed set {permitted.tolist()}"
|
|
1455
|
+
|
|
1456
|
+
@pytest.mark.parametrize("top_k", [1, 5, 99])
def test_top_k(self, rejection_sampler, test_helper, top_k):
    """Check that top-k masking restricts sampling to the k favoured tokens.

    Each token position gets a random set of ``top_k`` token ids whose
    logits are raised slightly above an otherwise flat row; with top-k
    enabled only those ids may appear in the sampler output.
    """
    vocab_size, batch_size, num_draft_tokens = 100, 10, 3
    num_tokens = batch_size * num_draft_tokens

    # A distinct permutation per position, derived from one base key.
    base_key = jax.random.PRNGKey(42)
    top_k_indices = [
        jax.random.permutation(jax.random.fold_in(base_key, row),
                               vocab_size)[:top_k]
        for row in range(num_tokens)
    ]

    # Start from uniform logits, then bump the chosen ids so they are the
    # top-k entries of their row. If masking works correctly, only these
    # tokens should be sampled.
    target_logits = jnp.zeros((num_tokens, vocab_size), dtype=jnp.float32)
    for row, favoured in enumerate(top_k_indices):
        target_logits = target_logits.at[row, favoured].add(0.1)

    sampling_metadata = test_helper.create_sampling_metadata(
        all_greedy=False,
        batch_size=batch_size,
        top_k=top_k,
        top_p=1.0,
        temperature=1.0,
    )

    self._test_masked_logits(
        rejection_sampler=rejection_sampler,
        batch_size=batch_size,
        num_draft_tokens=num_draft_tokens,
        vocab_size=vocab_size,
        target_logits=target_logits,
        allowed_tokens_per_pos=top_k_indices,
        sampling_metadata=sampling_metadata,
    )
|
|
1499
|
+
|
|
1500
|
+
@pytest.mark.parametrize("top_p", [0.5, 0.9, 0.99])
def test_top_p(self, rejection_sampler, test_helper, top_p):
    """Check that top-p masking restricts sampling to the nucleus tokens.

    The expected nucleus is computed here independently: logits are sorted,
    the low-probability prefix whose cumulative mass is at most
    ``1 - top_p`` is dropped, and the remaining ids form the allowed set
    for each token position.
    """
    vocab_size, batch_size, num_draft_tokens = 100, 10, 3
    num_tokens = batch_size * num_draft_tokens

    # Random target logits
    target_logits = jax.random.normal(jax.random.PRNGKey(42),
                                      (num_tokens, vocab_size))

    # Per-sequence temperature, expanded to one value per token position.
    temperature = jnp.ones(batch_size, dtype=jnp.float32)
    per_token_temperature = temperature.repeat(num_draft_tokens, axis=0)
    rescaled_logits = target_logits / per_token_temperature[:, None]

    # Sorted logits together with the token ids they came from.
    logits_idx = jnp.argsort(rescaled_logits, axis=-1)
    logits_sorted = jnp.sort(rescaled_logits, axis=-1)
    probs_cumsum = jnp.cumsum(jax.nn.softmax(logits_sorted, axis=-1),
                              axis=-1)

    # True marks tokens to discard; the last sorted entry is always kept so
    # at least one token survives.
    discard_mask = probs_cumsum <= (1 - top_p)
    discard_mask = discard_mask.at[:, -1].set(False)

    # Allowed token ids for each position
    top_p_indices = [
        logits_idx[row][~discard_mask[row]] for row in range(num_tokens)
    ]

    sampling_metadata = test_helper.create_sampling_metadata(
        all_greedy=False,
        batch_size=batch_size,
        top_k=-1,
        top_p=top_p,
        temperature=1.0,
    )

    self._test_masked_logits(
        rejection_sampler=rejection_sampler,
        batch_size=batch_size,
        num_draft_tokens=num_draft_tokens,
        vocab_size=vocab_size,
        target_logits=target_logits,
        allowed_tokens_per_pos=top_p_indices,
        sampling_metadata=sampling_metadata,
    )
|
|
1554
|
+
|
|
1555
|
+
def test_top_k_and_top_p_combined(self, rejection_sampler, test_helper):
    """Test rejection sampling with both top-k and top-p applied.

    This test verifies that both top-k and top-p can be used together
    without errors, but doesn't verify the exact masking behavior since
    the order of application may vary from our test implementation.
    It does check that the parsed output is well-formed: one entry per
    sequence, no more than the draft tokens plus one bonus token per
    sequence, and every token id inside the vocabulary.
    """
    vocab_size = 50
    batch_size = 5
    num_draft_tokens = 2
    num_tokens = batch_size * num_draft_tokens
    top_k = 10
    top_p = 0.8

    # Create random target logits
    key = jax.random.PRNGKey(123)
    target_logits = jax.random.normal(key, (num_tokens, vocab_size))

    # Create random draft probabilities
    draft_key = jax.random.PRNGKey(42)
    draft_logits = jax.random.normal(draft_key, (num_tokens, vocab_size))
    draft_probs = jax.nn.softmax(draft_logits, axis=-1)

    # Randomly sample draft token ids from draft probs
    sample_key = jax.random.PRNGKey(123)
    draft_token_ids = jax.random.categorical(sample_key,
                                             jnp.log(draft_probs + 1e-8),
                                             shape=(num_tokens, ))

    # Create sampling metadata with both top-k and top-p
    sampling_metadata = test_helper.create_sampling_metadata(
        all_greedy=False,
        batch_size=batch_size,
        top_k=top_k,
        top_p=top_p,
        temperature=1.0,
    )

    # Prepare inputs
    num_draft_per_seq = jnp.full((batch_size, ),
                                 num_draft_tokens,
                                 dtype=jnp.int32)
    bonus_token_ids = jnp.zeros((batch_size, ), dtype=jnp.int32)

    # Run the combined sampling once
    run_key = jax.random.PRNGKey(456)
    output_token_ids = rejection_sampler(
        draft_token_ids=draft_token_ids,
        num_draft_tokens=num_draft_per_seq,
        draft_probs=draft_probs,
        target_logits=target_logits,
        bonus_token_ids=bonus_token_ids,
        sampling_metadata=sampling_metadata,
        key=run_key,
    )

    # Parse output to verify it's well-formed
    parsed_output = rejection_sampler.parse_output(
        output_token_ids,
        vocab_size=vocab_size,
        num_draft_tokens_cpu=np.asarray(num_draft_per_seq),
        batch_size=batch_size,
        padded_tokens_length=num_tokens)

    # Basic sanity checks
    assert len(parsed_output) == batch_size
    # A sequence can hold at most its draft tokens plus one bonus token.
    # (The previous `len(seq_tokens) >= 0` check could never fail.)
    max_tokens_per_seq = num_draft_tokens + 1
    for seq_tokens in parsed_output:
        assert len(seq_tokens) <= max_tokens_per_seq
        for token_id in seq_tokens:
            assert 0 <= token_id < vocab_size  # Valid token range
|