tpu_inference-0.12.0.dev20251222-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
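
The single per-file diff reproduced below is tests/layers/jax/sample/test_rejection_sampler.py, the only entry above with 1,624 added lines, which matches the hunk header that follows; the other 259 files are not shown in this excerpt.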
@@ -0,0 +1,1624 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Tests for the JAX-based rejection sampler for speculative decoding on TPU.
+This test suite is structured to mirror the GPU rejection sampler tests.
+"""
+from dataclasses import dataclass
+from typing import List, Tuple
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+
+from tpu_inference.layers.jax.sample.rejection_sampler import (
+    PLACEHOLDER_TOKEN_ID, RejectionSampler)
+from tpu_inference.layers.jax.sample.sampling_metadata import \
+    TPUSupportedSamplingMetadata
+
+# ======================== CONSTANTS ========================
+
+PAD_TOKEN_ID = -999  # Padding token for draft_token_ids
+VOCAB_SIZE = 128  # Default vocabulary size for tests
+DEFAULT_PADDING_FACTOR = 1.5  # Default padding factor for padded tests
+
+# ======================== DATA STRUCTURES ========================
+
+
+@dataclass
+class RejectionSamplerTestCase:
+    """Test case data structure for rejection sampler scenarios."""
+    name: str
+    draft_tokens: List[int]
+    target_tokens: List[int]
+    num_draft_per_seq: List[int]  # number of draft tokens per sequence
+    bonus_tokens: List[int]
+    expected: List[List[int]]
+    description: str = ""
+    use_padding: bool = False  # Whether to add padding to draft tokens
+
+
+# ======================== TEST DATA FACTORY ========================
+
+
+class TestDataFactory:
+    """Factory class for generating test cases."""
+
+    @staticmethod
+    def create_test_case(
+            name: str,
+            draft_tokens: List[int],
+            target_tokens: List[int],
+            num_draft_per_seq: List[int],
+            bonus_tokens: List[int],
+            expected: List[List[int]],
+            description: str = "",
+            use_padding: bool = False) -> RejectionSamplerTestCase:
+        """Create a single test case."""
+        return RejectionSamplerTestCase(name=name,
+                                        draft_tokens=draft_tokens,
+                                        target_tokens=target_tokens,
+                                        num_draft_per_seq=num_draft_per_seq,
+                                        bonus_tokens=bonus_tokens,
+                                        expected=expected,
+                                        description=description
+                                        or name.replace("_", " ").title(),
+                                        use_padding=use_padding)
+
+    @classmethod
+    def create_with_padding_variant(
+            cls,
+            name: str,
+            draft_tokens: List[int],
+            target_tokens: List[int],
+            num_draft_per_seq: List[int],
+            bonus_tokens: List[int],
+            expected: List[List[int]],
+            description: str = "") -> List[RejectionSamplerTestCase]:
+        """Create both normal and padded versions of a test case."""
+        test_cases = []
+
+        # Create normal version
+        test_cases.append(
+            cls.create_test_case(name=name,
+                                 draft_tokens=draft_tokens,
+                                 target_tokens=target_tokens,
+                                 num_draft_per_seq=num_draft_per_seq,
+                                 bonus_tokens=bonus_tokens,
+                                 expected=expected,
+                                 description=description))
+
+        # Create padded version if there are tokens
+        if draft_tokens:
+            test_cases.append(
+                cls.create_test_case(
+                    name=f"{name}_padded",
+                    draft_tokens=draft_tokens,
+                    target_tokens=target_tokens,
+                    num_draft_per_seq=num_draft_per_seq,
+                    bonus_tokens=bonus_tokens,
+                    expected=expected,
+                    description=f"{description} (with padding)",
+                    use_padding=True))
+
+        return test_cases
+
+    @classmethod
+    def get_basic_test_cases(cls) -> List[RejectionSamplerTestCase]:
+        """Generate basic functionality test cases."""
+        test_cases = []
+
+        # Perfect match
+        test_cases.extend(
+            cls.create_with_padding_variant(
+                name="perfect_match",
+                draft_tokens=[1, 2, 3],
+                target_tokens=[1, 2, 3],
+                num_draft_per_seq=[3],
+                bonus_tokens=[4],
+                expected=[[1, 2, 3, 4]],
+                description="Draft tokens perfectly match target argmax"))
+
+        # Early mismatch
+        test_cases.extend(
+            cls.create_with_padding_variant(
+                name="early_mismatch",
+                draft_tokens=[1, 2, 3],
+                target_tokens=[1, 5, 3],
+                num_draft_per_seq=[3],
+                bonus_tokens=[4],
+                expected=[[1, 5]],
+                description="Mismatch at position 1"))
+
+        # Multiple sequences
+        test_cases.extend(
+            cls.create_with_padding_variant(
+                name="multiple_sequences",
+                draft_tokens=[1, 2, 3, 4],
+                target_tokens=[1, 2, 3, 7],
+                num_draft_per_seq=[2, 2],
+                bonus_tokens=[5, 6],
+                expected=[[1, 2, 5], [3, 7]],
+                description="Multiple sequences with mixed results"))
+
+        # Single token sequence
+        test_cases.extend(
+            cls.create_with_padding_variant(
+                name="single_token_sequence",
+                draft_tokens=[1],
+                target_tokens=[1],
+                num_draft_per_seq=[1],
+                bonus_tokens=[2],
+                expected=[[1, 2]],
+                description="Single token sequence with perfect match"))
+
+        # Empty sequence (no padding variant)
+        test_cases.append(
+            cls.create_test_case(
+                name="empty_sequence",
+                draft_tokens=[],
+                target_tokens=[],
+                num_draft_per_seq=[0],
+                bonus_tokens=[5],
+                expected=[[5]],
+                description="Empty sequence gets bonus token"))
+
+        return test_cases
+
+    @classmethod
+    def get_variable_length_test_cases(cls) -> List[RejectionSamplerTestCase]:
+        """Generate variable length test cases."""
+        test_cases = []
+
+        # Variable length sequences
+        test_cases.extend(
+            cls.create_with_padding_variant(
+                name="variable_length_sequences",
+                draft_tokens=[1, 2, 3],
+                target_tokens=[1, 5, 3],
+                num_draft_per_seq=[2, 1],
+                bonus_tokens=[6, 7],
+                expected=[[1, 5], [3, 7]],
+                description="Sequences with different lengths"))
+
+        # All different lengths
+        test_cases.extend(
+            cls.create_with_padding_variant(
+                name="all_different_lengths",
+                draft_tokens=[1, 2, 3, 4, 5, 6],
+                target_tokens=[1, 2, 3, 4, 5, 6],
+                num_draft_per_seq=[1, 2, 3],
+                bonus_tokens=[7, 9, 10],
+                expected=[[1, 7], [2, 3, 9], [4, 5, 6, 10]],
+                description="All sequences have different lengths"))
+
+        # Mixed sequence lengths
+        test_cases.extend(
+            cls.create_with_padding_variant(
+                name="mixed_sequence_lengths",
+                draft_tokens=[1, 2, 3, 4, 5],
+                target_tokens=[1, 2, 3, 7, 5],
+                num_draft_per_seq=[2, 3],
+                bonus_tokens=[6, 8],
+                expected=[[1, 2, 6], [3, 7]],
+                description="Mixed lengths with different outcomes"))
+
+        return test_cases
+
+    @classmethod
+    def get_edge_case_test_cases(cls) -> List[RejectionSamplerTestCase]:
+        """Generate edge case test cases."""
+        test_cases = []
+
+        # Zero length mixed
+        test_cases.extend(
+            cls.create_with_padding_variant(
+                name="zero_length_mixed",
+                draft_tokens=[1, 2],
+                target_tokens=[1, 2],
+                num_draft_per_seq=[0, 2],
+                bonus_tokens=[5, 6],
+                expected=[[5], [1, 2, 6]],
+                description="Zero-length sequence mixed with normal"))
+
+        # All zero length (no padding variant)
+        test_cases.append(
+            cls.create_test_case(name="all_zero_length",
+                                 draft_tokens=[],
+                                 target_tokens=[],
+                                 num_draft_per_seq=[0, 0],
+                                 bonus_tokens=[5, 6],
+                                 expected=[[5], [6]],
+                                 description="All sequences are zero-length"))
+
+        # Immediate rejection
+        test_cases.extend(
+            cls.create_with_padding_variant(
+                name="immediate_rejection",
+                draft_tokens=[1, 2, 3, 4, 5, 6],
+                target_tokens=[9, 2, 3, 4, 5, 6],
+                num_draft_per_seq=[3, 2, 1],
+                bonus_tokens=[10, 11, 12],
+                expected=[[9], [4, 5, 11], [6, 12]],
+                description="Mixed immediate rejection and perfect matches"))
+
+        # First token mismatch
+        test_cases.extend(
+            cls.create_with_padding_variant(
+                name="first_token_mismatch",
+                draft_tokens=[1],
+                target_tokens=[2],
+                num_draft_per_seq=[1],
+                bonus_tokens=[3],
+                expected=[[2]],
+                description="Single token mismatch"))
+
+        return test_cases
+
+    @classmethod
+    def get_all_test_cases(cls) -> List[RejectionSamplerTestCase]:
+        """Get all test cases including basic, variable length, and edge cases."""
+        all_cases = []
+        all_cases.extend(cls.get_basic_test_cases())
+        all_cases.extend(cls.get_variable_length_test_cases())
+        all_cases.extend(cls.get_edge_case_test_cases())
+        return all_cases
+
+
+# ======================== TEST HELPERS ========================
+
+
+class RejectionSamplerTestHelper:
+    """Helper class for rejection sampler tests."""
+
+    @staticmethod
+    def create_target_logits_from_tokens(
+            target_token_ids: List[int],
+            vocab_size: int = VOCAB_SIZE) -> jnp.ndarray:
+        """
+        Create target logits that will produce desired token ids on argmax.
+
+        Args:
+            target_token_ids: List of target token IDs
+            vocab_size: Size of the vocabulary
+
+        Returns:
+            JAX array of target logits
+        """
+        num_tokens = len(target_token_ids)
+        if num_tokens == 0:
+            return jnp.empty((0, vocab_size), dtype=jnp.float32)
+
+        # Create target logits with low values
+        target_logits = jnp.full((num_tokens, vocab_size),
+                                 -100.0,
+                                 dtype=jnp.float32)
+
+        # Set high values at desired token positions to make them the argmax
+        for i, token_id in enumerate(target_token_ids):
+            target_logits = target_logits.at[i, token_id].set(100.0)
+
+        return target_logits
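
An editorial aside: since each row of these logits is just a one-hot pattern scaled to +/-100, the loop above has a vectorized equivalent. A sketch, not code from the package:

    # 100.0 at each target token id, -100.0 everywhere else.
    target_logits = 200.0 * jax.nn.one_hot(
        jnp.array(target_token_ids, dtype=jnp.int32), vocab_size,
        dtype=jnp.float32) - 100.0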
+
+    @staticmethod
+    def create_sampling_metadata(
+        all_greedy: bool = True,
+        batch_size: int = 1,
+        top_k: int = -1,
+        top_p: float = 1.0,
+        temperature: float = 1.0,
+    ) -> TPUSupportedSamplingMetadata:
+        """
+        Create TPU sampling metadata object.
+        """
+        return TPUSupportedSamplingMetadata(
+            do_sampling=not all_greedy,
+            logprobs=False,
+            top_k=jnp.full((batch_size, ), top_k, dtype=jnp.int32),
+            top_p=jnp.full((batch_size, ), top_p, dtype=jnp.float32),
+            temperature=jnp.full((batch_size, ),
+                                 temperature,
+                                 dtype=jnp.float32),
+        )
+
+    @staticmethod
+    def create_padded_draft_tokens(
+            draft_tokens: List[int],
+            padding_factor: float = DEFAULT_PADDING_FACTOR) -> jnp.ndarray:
+        """
+        Create padded draft tokens array.
+
+        Args:
+            draft_tokens: List of draft tokens
+            padding_factor: Factor to determine padding length
+
+        Returns:
+            JAX array of padded tokens
+        """
+        if not draft_tokens:
+            return jnp.array([], dtype=jnp.int32)
+
+        # Calculate padded length: at least 2 extra slots, or
+        # padding_factor times the actual length, whichever is larger
+        actual_length = len(draft_tokens)
+        padded_length = max(actual_length + 2,
+                            int(actual_length * padding_factor))
+
+        # Create padded array
+        padded_tokens = [PAD_TOKEN_ID] * padded_length
+
+        # Copy actual tokens to the beginning
+        for i, token in enumerate(draft_tokens):
+            padded_tokens[i] = token
+
+        return jnp.array(padded_tokens, dtype=jnp.int32)
+
+    @staticmethod
+    def prepare_test_inputs(
+        test_case: RejectionSamplerTestCase,
+        vocab_size: int = VOCAB_SIZE
+    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
+        """
+        Prepare inputs for rejection sampler test.
+
+        Args:
+            test_case: Test case with input data
+            vocab_size: Vocabulary size
+
+        Returns:
+            Tuple of (draft_token_ids, target_logits, num_draft_tokens,
+            bonus_token_ids)
+        """
+        helper = RejectionSamplerTestHelper()
+
+        # Prepare draft tokens (with or without padding)
+        if test_case.use_padding and test_case.draft_tokens:
+            # For padded inputs, simulate how a real system would handle padding
+            padded_draft_tokens = helper.create_padded_draft_tokens(
+                test_case.draft_tokens)
+
+            # Extract only the actual tokens
+            num_draft_tokens = jnp.array(test_case.num_draft_per_seq,
+                                         dtype=jnp.int32)
+            total_actual_tokens = int(jnp.sum(num_draft_tokens))
+
+            # Extract only the first total_actual_tokens from the padded array
+            draft_token_ids = padded_draft_tokens[:total_actual_tokens]
+            target_logits = helper.create_target_logits_from_tokens(
+                test_case.target_tokens, vocab_size)
+        else:
+            draft_token_ids = jnp.array(test_case.draft_tokens,
+                                        dtype=jnp.int32)
+            target_logits = helper.create_target_logits_from_tokens(
+                test_case.target_tokens, vocab_size)
+            num_draft_tokens = jnp.array(test_case.num_draft_per_seq,
+                                         dtype=jnp.int32)
+
+        bonus_token_ids = jnp.array(test_case.bonus_tokens, dtype=jnp.int32)
+
+        return (draft_token_ids, target_logits, num_draft_tokens,
+                bonus_token_ids)
+
+    @staticmethod
+    def run_rejection_sampler_test(
+        rejection_sampler: RejectionSampler,
+        test_case: RejectionSamplerTestCase,
+        vocab_size: int = VOCAB_SIZE,
+    ) -> None:
+        """
+        Run a rejection sampler test from test case data.
+
+        Args:
+            rejection_sampler: RejectionSampler instance
+            test_case: Test case to run
+            vocab_size: Vocabulary size
+        """
+        helper = RejectionSamplerTestHelper()
+        metadata = helper.create_sampling_metadata(all_greedy=True)
+
+        # Prepare inputs
+        (draft_token_ids, target_logits, num_draft_tokens,
+         bonus_token_ids) = helper.prepare_test_inputs(test_case, vocab_size)
+
+        # Call the rejection sampler
+        output = rejection_sampler(
+            draft_token_ids=draft_token_ids,
+            num_draft_tokens=num_draft_tokens,
+            draft_probs=None,
+            target_logits=target_logits,
+            bonus_token_ids=bonus_token_ids,
+            sampling_metadata=metadata,
+        )
+
+        # Parse the output
+        parsed_output = rejection_sampler.parse_output(
+            output,
+            vocab_size=vocab_size,
+            num_draft_tokens_cpu=np.asarray(num_draft_tokens),
+            batch_size=len(num_draft_tokens),
+            padded_tokens_length=int(sum(num_draft_tokens)))
+
+        assert parsed_output == test_case.expected, \
+            f"Test '{test_case.name}': Expected {test_case.expected}, got {parsed_output}"
+
+
+# ======================== FIXTURES ========================
+
+
+@pytest.fixture
+def rejection_sampler():
+    """Fixture for the RejectionSampler."""
+    return RejectionSampler()
+
+
+@pytest.fixture
+def test_helper():
+    """Fixture for the test helper."""
+    return RejectionSamplerTestHelper()
+
+
+@pytest.fixture
+def test_factory():
+    """Fixture for the test data factory."""
+    return TestDataFactory()
+
+
+# ======================== TEST CLASSES ========================
+
+
+class TestRejectionSampler:
+    """Comprehensive test suite for rejection sampler."""
+
+    # =============== Basic Functionality Tests ===============
+
+    @pytest.mark.parametrize("test_case",
+                             TestDataFactory.get_all_test_cases(),
+                             ids=lambda tc: tc.name)
+    def test_rejection_sampler_scenarios(self, rejection_sampler, test_case):
+        """Test all rejection sampler scenarios including padded versions."""
+        RejectionSamplerTestHelper.run_rejection_sampler_test(
+            rejection_sampler, test_case)
+
+    def test_multiple_mismatches(self, rejection_sampler, test_factory):
+        """Test handling multiple sequences where both have mismatches."""
+        test_cases = test_factory.create_with_padding_variant(
+            name="multiple_mismatches",
+            draft_tokens=[1, 2, 3, 4, 5, 6],
+            target_tokens=[1, 2, 7, 4, 8, 6],
+            num_draft_per_seq=[3, 3],
+            bonus_tokens=[8, 9],
+            expected=[[1, 2, 7], [4, 8]],
+            description="Both sequences have mismatches")
+
+        for test_case in test_cases:
+            RejectionSamplerTestHelper.run_rejection_sampler_test(
+                rejection_sampler, test_case)
+
+    # =============== Parse Output Tests ===============
+
+    def test_parse_output_basic(self, rejection_sampler):
+        """Test the parse_output method with basic flattened format."""
+        vocab_size = VOCAB_SIZE
+
+        # Create flattened output: [main_tokens..., bonus_tokens...]
+        main_tokens = jnp.array([10, 20, 30, 50, 60], dtype=jnp.int32)
+        bonus_tokens = jnp.array([40, 70], dtype=jnp.int32)
+        output_token_ids = jnp.concatenate([main_tokens, bonus_tokens])
+
+        num_draft_tokens = jnp.array([3, 2], dtype=jnp.int32)
+
+        parsed_output = rejection_sampler.parse_output(
+            output_token_ids,
+            vocab_size,
+            num_draft_tokens_cpu=np.asarray(num_draft_tokens),
+            batch_size=len(num_draft_tokens),
+            padded_tokens_length=int(sum(num_draft_tokens)))
+
+        expected = [[10, 20, 30, 40], [50, 60, 70]]
+        assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
+
+    def test_parse_output_with_placeholders(self, rejection_sampler):
+        """Test parse_output with rejected tokens (placeholders)."""
+        vocab_size = VOCAB_SIZE
+
+        # Test with rejected tokens (placeholders)
+        main_tokens = jnp.array(
+            [10, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID, 20, 30],
+            dtype=jnp.int32)
+        bonus_tokens = jnp.array([PLACEHOLDER_TOKEN_ID, 40], dtype=jnp.int32)
+        output_token_ids = jnp.concatenate([main_tokens, bonus_tokens])
+
+        num_draft_tokens = jnp.array([3, 2], dtype=jnp.int32)
+
+        parsed_output = rejection_sampler.parse_output(
+            output_token_ids,
+            vocab_size,
+            num_draft_tokens_cpu=np.asarray(num_draft_tokens),
+            batch_size=len(num_draft_tokens),
+            padded_tokens_length=int(sum(num_draft_tokens)))
+
+        expected = [[10], [20, 30, 40]]
+        assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
+
+    def test_parse_output_invalid_tokens(self, rejection_sampler):
+        """Test parse_output with tokens outside vocab size."""
+        vocab_size = VOCAB_SIZE
+
+        # Test with tokens outside vocab size
+        main_tokens = jnp.array([10, vocab_size + 1, 20], dtype=jnp.int32)
+        bonus_tokens = jnp.array([vocab_size + 2], dtype=jnp.int32)
+        output_token_ids = jnp.concatenate([main_tokens, bonus_tokens])
+
+        num_draft_tokens = jnp.array([3], dtype=jnp.int32)
+
+        parsed_output = rejection_sampler.parse_output(
+            output_token_ids,
+            vocab_size,
+            num_draft_tokens_cpu=np.asarray(num_draft_tokens),
+            batch_size=len(num_draft_tokens),
+            padded_tokens_length=int(sum(num_draft_tokens)))
+
+        expected = [[10, 20]]  # Invalid tokens filtered out
+        assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
+
+    def test_parse_output_empty_sequences(self, rejection_sampler):
+        """Test parse_output with empty sequences."""
+        vocab_size = VOCAB_SIZE
+
+        # Test with empty sequences
+        main_tokens = jnp.array([], dtype=jnp.int32)
+        bonus_tokens = jnp.array([50, 60], dtype=jnp.int32)
+        output_token_ids = jnp.concatenate([main_tokens, bonus_tokens])
+
+        num_draft_tokens = jnp.array([0, 0], dtype=jnp.int32)
+
+        parsed_output = rejection_sampler.parse_output(
+            output_token_ids,
+            vocab_size,
+            num_draft_tokens_cpu=np.asarray(num_draft_tokens),
+            batch_size=len(num_draft_tokens),
+            padded_tokens_length=int(sum(num_draft_tokens)))
+
+        expected = [[50], [60]]  # Only bonus tokens
+        assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
+
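Taken together, the four parse_output tests above pin down the flattened output layout: the first padded_tokens_length entries hold every sequence's main tokens back to back, followed by one bonus slot per sequence, and placeholder or out-of-vocabulary ids are dropped during parsing. The pure-Python reference below reproduces that behavior for these test inputs; it is an editorial sketch (assuming PLACEHOLDER_TOKEN_ID falls outside [0, vocab_size)), not the package's implementation:

    def parse_output_reference(tokens, vocab_size, num_draft_tokens_cpu,
                               batch_size, padded_tokens_length):
        # Layout: [seq0 main tokens, seq1 main tokens, ..., one bonus per seq].
        main = list(tokens[:padded_tokens_length])
        bonus = list(tokens[padded_tokens_length:])
        result, start = [], 0
        for i in range(batch_size):
            seq = main[start:start + num_draft_tokens_cpu[i]] + [bonus[i]]
            # Keep only valid vocabulary ids; placeholder/rejected slots drop out.
            result.append([int(t) for t in seq if 0 <= t < vocab_size])
            start += num_draft_tokens_cpu[i]
        return result
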
+ # =============== Padding-Specific Tests ===============
597
+
598
+ def test_padding_ignored_correctly(self, rejection_sampler, test_factory):
599
+ """Test that padding tokens are completely ignored."""
600
+ # Both versions should produce identical results
601
+ test_cases = test_factory.create_with_padding_variant(
602
+ name="padding_test",
603
+ draft_tokens=[1, 2],
604
+ target_tokens=[1, 5],
605
+ num_draft_per_seq=[2],
606
+ bonus_tokens=[3],
607
+ expected=[[1, 5]],
608
+ description="Test padding is ignored")
609
+
610
+ for test_case in test_cases:
611
+ RejectionSamplerTestHelper.run_rejection_sampler_test(
612
+ rejection_sampler, test_case)
613
+
614
+ def test_extreme_padding(self, rejection_sampler, test_helper):
615
+ """Test with extreme padding (much longer than actual tokens)."""
616
+ metadata = test_helper.create_sampling_metadata(all_greedy=True)
617
+
618
+ # Create heavily padded input: [1, 2] + 20 padding tokens
619
+ draft_tokens_with_extreme_padding = [1, 2] + [PAD_TOKEN_ID] * 20
620
+ padded_draft_tokens = jnp.array(draft_tokens_with_extreme_padding,
621
+ dtype=jnp.int32)
622
+
623
+ # Extract only the actual tokens (first 2)
624
+ num_draft_tokens = jnp.array([2], dtype=jnp.int32)
625
+ total_actual_tokens = int(jnp.sum(num_draft_tokens))
626
+ draft_token_ids = padded_draft_tokens[:total_actual_tokens]
627
+
628
+ target_logits = test_helper.create_target_logits_from_tokens(
629
+ [1, 5], VOCAB_SIZE)
630
+ bonus_token_ids = jnp.array([3], dtype=jnp.int32)
631
+
632
+ output = rejection_sampler(
633
+ draft_token_ids=draft_token_ids,
634
+ num_draft_tokens=num_draft_tokens,
635
+ draft_probs=None,
636
+ target_logits=target_logits,
637
+ bonus_token_ids=bonus_token_ids,
638
+ sampling_metadata=metadata,
639
+ )
640
+
641
+ parsed_output = rejection_sampler.parse_output(
642
+ output,
643
+ VOCAB_SIZE,
644
+ num_draft_tokens_cpu=np.asarray(num_draft_tokens),
645
+ batch_size=len(num_draft_tokens),
646
+ padded_tokens_length=int(sum(num_draft_tokens)))
647
+
648
+ expected = [[1, 5]] # Should ignore all padding
649
+ assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
650
+
651
+ def test_realistic_flattened_with_padding(self, rejection_sampler,
652
+ test_factory):
653
+ """Test with realistic flattened input including padding."""
654
+ test_case = test_factory.create_test_case(
655
+ name="realistic_flattened_with_padding",
656
+ draft_tokens=[1, 2, 3],
657
+ target_tokens=[1, 5, 3],
658
+ num_draft_per_seq=[2, 1],
659
+ bonus_tokens=[6, 7],
660
+ expected=[[1, 5], [3, 7]],
661
+ description="Realistic flattened input with padding",
662
+ use_padding=True)
663
+ RejectionSamplerTestHelper.run_rejection_sampler_test(
664
+ rejection_sampler, test_case)
665
+
666
+ # =============== Segment Operation Edge Case Tests ===============
667
+
668
+ def test_all_sequences_immediate_mismatch(self, rejection_sampler,
669
+ test_factory):
670
+ """Test where all sequences have immediate mismatches (first token rejected)."""
671
+ test_cases = test_factory.create_with_padding_variant(
672
+ name="all_immediate_mismatch",
673
+ draft_tokens=[1, 2, 3, 4, 5, 6, 7, 8, 9],
674
+ target_tokens=[10, 2, 3, 11, 5, 6, 12, 8,
675
+ 9], # All first tokens mismatch
676
+ num_draft_per_seq=[3, 3, 3],
677
+ bonus_tokens=[20, 21, 22],
678
+ expected=[[10], [11], [12]], # Only correction tokens, no bonus
679
+ description="All sequences have immediate first token mismatch")
680
+
681
+ for test_case in test_cases:
682
+ RejectionSamplerTestHelper.run_rejection_sampler_test(
683
+ rejection_sampler, test_case)
684
+
685
+ def test_all_sequences_perfect_match(self, rejection_sampler,
686
+ test_factory):
687
+ """Test where all sequences have perfect matches (all tokens accepted)."""
688
+ test_cases = test_factory.create_with_padding_variant(
689
+ name="all_perfect_match",
690
+ draft_tokens=[1, 2, 3, 4, 5, 6, 7, 8, 9],
691
+ target_tokens=[1, 2, 3, 4, 5, 6, 7, 8,
692
+ 9], # All tokens match perfectly
693
+ num_draft_per_seq=[3, 3, 3],
694
+ bonus_tokens=[10, 11, 12],
695
+ expected=[[1, 2, 3, 10], [4, 5, 6, 11],
696
+ [7, 8, 9, 12]], # All accepted + bonus
697
+ description="All sequences have perfect token matches")
698
+
699
+ for test_case in test_cases:
700
+ RejectionSamplerTestHelper.run_rejection_sampler_test(
701
+ rejection_sampler, test_case)
702
+
703
+ def test_extreme_length_imbalance(self, rejection_sampler, test_factory):
704
+ """Test with extreme length imbalance between sequences."""
705
+ # One very long sequence (15 tokens) with others being short (1-2 tokens)
706
+ test_case = test_factory.create_test_case(
707
+ name="extreme_length_imbalance",
708
+ draft_tokens=[
709
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18
710
+ ],
711
+ target_tokens=[
712
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 20, 18
713
+ ],
714
+ num_draft_per_seq=[15, 1, 2], # Very imbalanced lengths
715
+ bonus_tokens=[100, 101, 102],
716
+ expected=[
717
+ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
718
+ 100], # All 15 accepted + bonus
719
+ [16, 101], # Single token accepted + bonus
720
+ [20]
721
+ ], # First token mismatch, no bonus
722
+ description="Extreme length imbalance between sequences")
723
+ RejectionSamplerTestHelper.run_rejection_sampler_test(
724
+ rejection_sampler, test_case)
725
+
726
+ def test_mixed_accept_reject_patterns(self, rejection_sampler,
727
+ test_factory):
728
+ """Test mixed scenarios with perfect matches and immediate rejections."""
729
+ test_cases = test_factory.create_with_padding_variant(
730
+ name="mixed_accept_reject",
731
+ draft_tokens=[1, 2, 3, 4, 5, 6, 7, 8, 9],
732
+ target_tokens=[
733
+ 1, 2, 3, 10, 5, 6, 7, 8, 9
734
+ ], # First: perfect, Second: immediate reject, Third: perfect
735
+ num_draft_per_seq=[3, 3, 3],
736
+ bonus_tokens=[20, 21, 22],
737
+ expected=[[1, 2, 3, 20], [10], [7, 8, 9, 22]], # Mixed results
738
+ description="Mix of perfect matches and immediate rejections")
739
+
740
+ for test_case in test_cases:
741
+ RejectionSamplerTestHelper.run_rejection_sampler_test(
742
+ rejection_sampler, test_case)
743
+
744
+ def test_mismatches_at_same_position(self, rejection_sampler,
745
+ test_factory):
746
+ """Test where mismatches occur at exactly the same position across sequences."""
747
+ test_cases = test_factory.create_with_padding_variant(
748
+ name="same_position_mismatch",
749
+ draft_tokens=[1, 2, 3, 4, 5, 6, 7, 8, 9],
750
+ target_tokens=[1, 10, 3, 4, 11, 6, 7, 12,
751
+ 9], # All mismatch at position 1 (middle token)
752
+ num_draft_per_seq=[3, 3, 3],
753
+ bonus_tokens=[20, 21, 22],
754
+ expected=[[1, 10], [4, 11], [7,
755
+ 12]], # All reject at same position
756
+ description="Mismatches at same position in all sequences")
757
+
758
+ for test_case in test_cases:
759
+ RejectionSamplerTestHelper.run_rejection_sampler_test(
760
+ rejection_sampler, test_case)
761
+
762
+ def test_single_long_sequence(self, rejection_sampler, test_helper):
763
+ """Test a single very long sequence (approaching MAX_SPEC_LEN)."""
764
+ metadata = test_helper.create_sampling_metadata(all_greedy=True)
765
+
766
+ # Create a sequence with 30 draft tokens (close to MAX_SPEC_LEN=32)
767
+ draft_tokens = list(range(1, 31))
768
+ target_tokens = list(range(1, 28)) + [99, 29, 30
769
+ ] # Mismatch at position 27
770
+
771
+ draft_token_ids = jnp.array(draft_tokens, dtype=jnp.int32)
772
+ target_logits = test_helper.create_target_logits_from_tokens(
773
+ target_tokens, VOCAB_SIZE)
774
+ num_draft_tokens = jnp.array([30], dtype=jnp.int32)
775
+ bonus_token_ids = jnp.array([100], dtype=jnp.int32)
776
+
777
+ output = rejection_sampler(
778
+ draft_token_ids=draft_token_ids,
779
+ num_draft_tokens=num_draft_tokens,
780
+ draft_probs=None,
781
+ target_logits=target_logits,
782
+ bonus_token_ids=bonus_token_ids,
783
+ sampling_metadata=metadata,
784
+ )
785
+
786
+ parsed_output = rejection_sampler.parse_output(
787
+ output,
788
+ VOCAB_SIZE,
789
+ num_draft_tokens_cpu=np.asarray(num_draft_tokens),
790
+ batch_size=len(num_draft_tokens),
791
+ padded_tokens_length=int(sum(num_draft_tokens)))
792
+
793
+ expected = [list(range(1, 28)) + [99]] # Tokens up to mismatch point
794
+ assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
795
+
796
+
797
+ # ======================== NON-GREEDY SAMPLING TESTS ========================
798
+
799
+
800
+ class TestNonGreedyRejectionSampler:
801
+ """Test suite for non-greedy (random) rejection sampling."""
802
+
803
+ def test_non_greedy_basic_functionality(self, rejection_sampler,
804
+ test_helper):
805
+ """Test basic non-greedy sampling functionality."""
806
+ metadata = test_helper.create_sampling_metadata(all_greedy=False)
807
+
808
+ # Create simple test case
809
+ draft_tokens = [10, 20, 30]
810
+ target_tokens = [10, 50, 30] # Mismatch at position 1
811
+
812
+ draft_token_ids = jnp.array(draft_tokens, dtype=jnp.int32)
813
+ target_logits = test_helper.create_target_logits_from_tokens(
814
+ target_tokens, VOCAB_SIZE)
815
+
816
+ # Create draft probabilities - make draft tokens highly likely
817
+ draft_probs = jnp.full((len(draft_tokens), VOCAB_SIZE),
818
+ -100.0,
819
+ dtype=jnp.float32)
820
+ for i, token_id in enumerate(draft_tokens):
821
+ draft_probs = draft_probs.at[i, token_id].set(100.0)
822
+
823
+ # Convert logits to probabilities for draft_probs
824
+ draft_probs = jax.nn.softmax(draft_probs, axis=-1)
825
+
826
+ num_draft_tokens = jnp.array([3], dtype=jnp.int32)
827
+ bonus_token_ids = jnp.array([99], dtype=jnp.int32)
828
+ key = jax.random.PRNGKey(42)
829
+
830
+ output = rejection_sampler(
831
+ draft_token_ids=draft_token_ids,
832
+ num_draft_tokens=num_draft_tokens,
833
+ draft_probs=draft_probs,
834
+ target_logits=target_logits,
835
+ bonus_token_ids=bonus_token_ids,
836
+ sampling_metadata=metadata,
837
+ key=key,
838
+ )
839
+
840
+ parsed_output = rejection_sampler.parse_output(
841
+ output,
842
+ VOCAB_SIZE,
843
+ num_draft_tokens_cpu=np.asarray(num_draft_tokens),
844
+ batch_size=1,
845
+ padded_tokens_length=3)
846
+
847
+ # For non-greedy sampling, exact output depends on random sampling
848
+ # but we can check that the first token should be accepted
849
+ assert len(parsed_output) == 1
850
+ assert len(parsed_output[0]) >= 1
851
+ assert parsed_output[0][0] == 10 # First token should match
852
+
853
+ def test_non_greedy_deterministic_with_seed(self, rejection_sampler,
854
+ test_helper):
855
+ """Test that non-greedy sampling is deterministic with the same seed."""
856
+ metadata = test_helper.create_sampling_metadata(all_greedy=False)
857
+
858
+ # Create test case
859
+ draft_tokens = [1, 2, 3, 4]
860
+ target_tokens = [1, 5, 3, 6] # Mismatches at positions 1 and 3
861
+
862
+ draft_token_ids = jnp.array(draft_tokens, dtype=jnp.int32)
863
+ target_logits = test_helper.create_target_logits_from_tokens(
864
+ target_tokens, VOCAB_SIZE)
865
+
866
+ # Create draft probabilities
867
+ draft_probs = jnp.full((len(draft_tokens), VOCAB_SIZE),
868
+ -100.0,
869
+ dtype=jnp.float32)
870
+ for i, token_id in enumerate(draft_tokens):
871
+ draft_probs = draft_probs.at[i, token_id].set(100.0)
872
+
873
+ # Convert logits to probabilities for draft_probs
874
+ draft_probs = jax.nn.softmax(draft_probs, axis=-1)
875
+
876
+ num_draft_tokens = jnp.array([4], dtype=jnp.int32)
877
+ bonus_token_ids = jnp.array([99], dtype=jnp.int32)
878
+
879
+ # Run with same seed multiple times
880
+ key = jax.random.PRNGKey(123)
881
+ outputs = []
882
+
883
+ for _ in range(5):
884
+ output = rejection_sampler(
885
+ draft_token_ids=draft_token_ids,
886
+ num_draft_tokens=num_draft_tokens,
887
+ draft_probs=draft_probs,
888
+ target_logits=target_logits,
889
+ bonus_token_ids=bonus_token_ids,
890
+ sampling_metadata=metadata,
891
+ key=key,
892
+ )
893
+
894
+ parsed_output = rejection_sampler.parse_output(
895
+ output,
896
+ VOCAB_SIZE,
897
+ num_draft_tokens_cpu=np.asarray(num_draft_tokens),
898
+ batch_size=1,
899
+ padded_tokens_length=4)
900
+ outputs.append(parsed_output)
901
+
902
+ # All outputs should be identical with same seed
903
+ for i in range(1, len(outputs)):
904
+ assert outputs[i] == outputs[
905
+ 0], f"Run {i}: {outputs[i]} != {outputs[0]}"
906
+
907
+ def test_non_greedy_with_draft_probs_none(self, rejection_sampler,
908
+ test_helper):
909
+ """Test non-greedy sampling when draft_probs is None."""
910
+ metadata = test_helper.create_sampling_metadata(all_greedy=False)
911
+
912
+ # Create test case
913
+ draft_tokens = [15, 25]
914
+ target_tokens = [15, 35] # Mismatch at position 1
915
+
916
+ draft_token_ids = jnp.array(draft_tokens, dtype=jnp.int32)
917
+ target_logits = test_helper.create_target_logits_from_tokens(
918
+ target_tokens, VOCAB_SIZE)
919
+
920
+ num_draft_tokens = jnp.array([2], dtype=jnp.int32)
921
+ bonus_token_ids = jnp.array([88], dtype=jnp.int32)
922
+ key = jax.random.PRNGKey(777)
923
+
924
+ output = rejection_sampler(
925
+ draft_token_ids=draft_token_ids,
926
+ num_draft_tokens=num_draft_tokens,
927
+ draft_probs=None, # No draft probabilities
928
+ target_logits=target_logits,
929
+ bonus_token_ids=bonus_token_ids,
930
+ sampling_metadata=metadata,
931
+ key=key,
932
+ )
933
+
934
+ parsed_output = rejection_sampler.parse_output(
935
+ output,
936
+ VOCAB_SIZE,
937
+ num_draft_tokens_cpu=np.asarray(num_draft_tokens),
938
+ batch_size=1,
939
+ padded_tokens_length=2)
940
+
941
+ # Should have valid output
942
+ assert len(parsed_output) == 1
943
+ assert len(parsed_output[0]) >= 1
944
+ assert parsed_output[0][0] == 15 # First token should match
945
+
946
+ def test_non_greedy_multiple_sequences(self, rejection_sampler,
947
+ test_helper):
948
+ """Test non-greedy sampling with multiple sequences."""
949
+ metadata = test_helper.create_sampling_metadata(all_greedy=False)
950
+
951
+ # Create test case with 3 sequences
952
+ draft_tokens = [1, 2, 3, 4, 5, 6, 7] # [1,2] [3,4,5] [6,7]
953
+ target_tokens = [1, 5, 3, 8, 5, 6,
954
+ 9] # Mismatches at different positions
955
+
956
+ draft_token_ids = jnp.array(draft_tokens, dtype=jnp.int32)
957
+ target_logits = test_helper.create_target_logits_from_tokens(
958
+ target_tokens, VOCAB_SIZE)
959
+
960
+ # Create draft probabilities
961
+ draft_probs = jnp.full((len(draft_tokens), VOCAB_SIZE),
962
+ -100.0,
963
+ dtype=jnp.float32)
964
+ for i, token_id in enumerate(draft_tokens):
965
+ draft_probs = draft_probs.at[i, token_id].set(100.0)
966
+
967
+ # Convert logits to probabilities for draft_probs
968
+ draft_probs = jax.nn.softmax(draft_probs, axis=-1)
969
+
970
+ num_draft_tokens = jnp.array([2, 3, 2], dtype=jnp.int32)
971
+ bonus_token_ids = jnp.array([11, 12, 13], dtype=jnp.int32)
972
+ key = jax.random.PRNGKey(456)
973
+
974
+ output = rejection_sampler(
975
+ draft_token_ids=draft_token_ids,
976
+ num_draft_tokens=num_draft_tokens,
977
+ draft_probs=draft_probs,
978
+ target_logits=target_logits,
979
+ bonus_token_ids=bonus_token_ids,
980
+ sampling_metadata=metadata,
981
+ key=key,
982
+ )
983
+
984
+ parsed_output = rejection_sampler.parse_output(
985
+ output,
986
+ VOCAB_SIZE,
987
+ num_draft_tokens_cpu=np.asarray(num_draft_tokens),
988
+ batch_size=3,
989
+ padded_tokens_length=7)
990
+
991
+ # Should have 3 sequences
992
+ assert len(parsed_output) == 3
993
+
994
+ # First sequence: [1, 2] -> [1, 5] (mismatch at pos 1)
995
+ assert parsed_output[0][0] == 1
996
+
997
+ # Second sequence: [3, 4, 5] -> [3, 8, 5] (mismatch at pos 1)
998
+ assert parsed_output[1][0] == 3
999
+
1000
+ # Third sequence: [6, 7] -> [6, 9] (mismatch at pos 1)
1001
+ assert parsed_output[2][0] == 6
1002
+
1003
+ def test_non_greedy_with_all_accepted_tokens(self, rejection_sampler,
1004
+ test_helper):
1005
+ """Test non-greedy sampling when all tokens are accepted (perfect match)."""
1006
+ metadata = test_helper.create_sampling_metadata(all_greedy=False)
1007
+
1008
+ # Perfect match case
1009
+ draft_tokens = [10, 20, 30]
1010
+ target_tokens = [10, 20, 30] # Perfect match
1011
+
1012
+ draft_token_ids = jnp.array(draft_tokens, dtype=jnp.int32)
1013
+ target_logits = test_helper.create_target_logits_from_tokens(
1014
+ target_tokens, VOCAB_SIZE)
1015
+
1016
+ # Create draft probabilities - make acceptance very likely
1017
+ draft_probs = jnp.full((len(draft_tokens), VOCAB_SIZE),
1018
+ -100.0,
1019
+ dtype=jnp.float32)
1020
+ for i, token_id in enumerate(draft_tokens):
1021
+ draft_probs = draft_probs.at[i, token_id].set(100.0)
1022
+
1023
+ # Convert logits to probabilities for draft_probs
1024
+ draft_probs = jax.nn.softmax(draft_probs, axis=-1)
1025
+
1026
+ num_draft_tokens = jnp.array([3], dtype=jnp.int32)
1027
+ bonus_token_ids = jnp.array([99], dtype=jnp.int32)
1028
+ key = jax.random.PRNGKey(999)
1029
+
1030
+ output = rejection_sampler(
1031
+ draft_token_ids=draft_token_ids,
1032
+ num_draft_tokens=num_draft_tokens,
1033
+ draft_probs=draft_probs,
1034
+ target_logits=target_logits,
1035
+ bonus_token_ids=bonus_token_ids,
1036
+ sampling_metadata=metadata,
1037
+ key=key,
1038
+ )
1039
+
1040
+ parsed_output = rejection_sampler.parse_output(
1041
+ output,
1042
+ VOCAB_SIZE,
1043
+ num_draft_tokens_cpu=np.asarray(num_draft_tokens),
1044
+ batch_size=1,
1045
+ padded_tokens_length=3)
1046
+
1047
+ # With perfect match and high acceptance probability, should get bonus token
1048
+ assert len(parsed_output) == 1
1049
+ # The exact output depends on random sampling, but should contain the draft tokens
1050
+
1051
+ def test_non_greedy_empty_sequence(self, rejection_sampler, test_helper):
1052
+ """Test non-greedy sampling with empty sequences."""
1053
+ metadata = test_helper.create_sampling_metadata(all_greedy=False)
1054
+
1055
+ # Empty sequences should get bonus tokens
1056
+ draft_token_ids = jnp.array([], dtype=jnp.int32)
1057
+ target_logits = jnp.array([], dtype=jnp.float32).reshape(0, VOCAB_SIZE)
1058
+
1059
+ num_draft_tokens = jnp.array([0, 0], dtype=jnp.int32)
1060
+ bonus_token_ids = jnp.array([77, 88], dtype=jnp.int32)
1061
+ key = jax.random.PRNGKey(333)
1062
+
1063
+ output = rejection_sampler(
1064
+ draft_token_ids=draft_token_ids,
1065
+ num_draft_tokens=num_draft_tokens,
1066
+ draft_probs=None,
1067
+ target_logits=target_logits,
1068
+ bonus_token_ids=bonus_token_ids,
1069
+ sampling_metadata=metadata,
1070
+ key=key,
1071
+ )
1072
+
1073
+ parsed_output = rejection_sampler.parse_output(
1074
+ output,
1075
+ VOCAB_SIZE,
1076
+ num_draft_tokens_cpu=np.asarray(num_draft_tokens),
1077
+ batch_size=2,
1078
+ padded_tokens_length=0)
1079
+
1080
+ # Should get bonus tokens for empty sequences
1081
+ expected = [[77], [88]]
1082
+ assert parsed_output == expected, f"Expected {expected}, got {parsed_output}"
1083
+
1084
+ def test_non_greedy_requires_key(self, rejection_sampler, test_helper):
1085
+ """Test that non-greedy sampling requires a random key."""
1086
+ metadata = test_helper.create_sampling_metadata(all_greedy=False)
1087
+
1088
+ # Create simple test case
1089
+ draft_tokens = [1, 2]
1090
+ target_tokens = [1, 3]
1091
+
1092
+ draft_token_ids = jnp.array(draft_tokens, dtype=jnp.int32)
1093
+ target_logits = test_helper.create_target_logits_from_tokens(
1094
+ target_tokens, VOCAB_SIZE)
1095
+
1096
+ num_draft_tokens = jnp.array([2], dtype=jnp.int32)
1097
+ bonus_token_ids = jnp.array([99], dtype=jnp.int32)
1098
+
1099
+ # Should raise ValueError when key is None for non-greedy sampling
1100
+ with pytest.raises(ValueError, match="A random key must be provided"):
1101
+ rejection_sampler(
1102
+ draft_token_ids=draft_token_ids,
1103
+ num_draft_tokens=num_draft_tokens,
1104
+ draft_probs=None,
1105
+ target_logits=target_logits,
1106
+ bonus_token_ids=bonus_token_ids,
1107
+ sampling_metadata=metadata,
1108
+ key=None, # No key provided
1109
+ )
1110
+
1111
+     def test_non_greedy_vs_greedy_same_perfect_case(self, rejection_sampler,
+                                                     test_helper):
+         """Test that greedy and non-greedy behave consistently for perfect matches."""
+         # Perfect match case - both should accept every draft token
+         draft_tokens = [5, 15, 25]
+         target_tokens = [5, 15, 25]  # Perfect match
+
+         draft_token_ids = jnp.array(draft_tokens, dtype=jnp.int32)
+         target_logits = test_helper.create_target_logits_from_tokens(
+             target_tokens, VOCAB_SIZE)
+
+         # Build sharply peaked draft logits, then convert them to probabilities
+         draft_logits = jnp.full((len(draft_tokens), VOCAB_SIZE),
+                                 -100.0,
+                                 dtype=jnp.float32)
+         for i, token_id in enumerate(draft_tokens):
+             draft_logits = draft_logits.at[i, token_id].set(100.0)
+         draft_probs = jax.nn.softmax(draft_logits, axis=-1)
+
+         num_draft_tokens = jnp.array([3], dtype=jnp.int32)
+         bonus_token_ids = jnp.array([99], dtype=jnp.int32)
+
+         # Greedy sampling
+         greedy_metadata = test_helper.create_sampling_metadata(all_greedy=True)
+         greedy_output = rejection_sampler(
+             draft_token_ids=draft_token_ids,
+             num_draft_tokens=num_draft_tokens,
+             draft_probs=draft_probs,
+             target_logits=target_logits,
+             bonus_token_ids=bonus_token_ids,
+             sampling_metadata=greedy_metadata,
+         )
+
+         # Non-greedy sampling. Due to its probabilistic nature we cannot
+         # guarantee identical outputs, but for a perfect match with sharply
+         # peaked probabilities the results should be very similar.
+         non_greedy_metadata = test_helper.create_sampling_metadata(
+             all_greedy=False)
+         key = jax.random.PRNGKey(555)
+         non_greedy_output = rejection_sampler(
+             draft_token_ids=draft_token_ids,
+             num_draft_tokens=num_draft_tokens,
+             draft_probs=draft_probs,
+             target_logits=target_logits,
+             bonus_token_ids=bonus_token_ids,
+             sampling_metadata=non_greedy_metadata,
+             key=key,
+         )
+
+         # Parse outputs
+         greedy_parsed = rejection_sampler.parse_output(
+             greedy_output, VOCAB_SIZE, np.asarray(num_draft_tokens), 1, 3)
+         non_greedy_parsed = rejection_sampler.parse_output(
+             non_greedy_output, VOCAB_SIZE, np.asarray(num_draft_tokens), 1, 3)
+
+         # For a perfect match, greedy should accept all tokens plus the bonus
+         assert greedy_parsed == [[5, 15, 25, 99]]
+
+         # Non-greedy should produce valid output (content may vary due to sampling)
+         assert len(non_greedy_parsed) == 1
+         assert len(non_greedy_parsed[0]) >= 1
+
+
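For reference, a standalone sketch of the greedy acceptance rule this test leans on. This is our illustration of the standard speculative-decoding convention; `greedy_accept` is a hypothetical helper, not the sampler's API:

```python
import numpy as np

def greedy_accept(draft_tokens, target_logits, bonus_token):
    """Keep draft tokens while each matches the target argmax; on the first
    mismatch, emit the target's greedy choice instead and stop. If every
    draft token is accepted, append the bonus token."""
    out = []
    for tok, logits in zip(draft_tokens, target_logits):
        greedy = int(np.argmax(logits))
        if greedy != tok:
            out.append(greedy)
            return out
        out.append(tok)
    return out + [bonus_token]

# Perfect match, as in the test: expect all draft tokens plus the bonus.
logits = np.full((3, 100), -100.0)
for i, tok in enumerate([5, 15, 25]):
    logits[i, tok] = 100.0
assert greedy_accept([5, 15, 25], logits, 99) == [5, 15, 25, 99]
```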
+ # ======================== STATISTICAL DISTRIBUTION VALIDATION ========================
+
+
+ class TestStatisticalDistributionValidation:
+     """Test suite validating that rejection sampling produces correct probability distributions."""
+
+     def test_rejection_sampling_approximates_target_distribution(self):
+         """Verify that rejection sampling approximates the target distribution.
+
+         This test validates that rejection sampling produces the correct
+         probability distribution even though it samples from a potentially
+         different draft distribution.
+
+         The test works by:
+         1. Creating random target and draft probability distributions
+         2. Using rejection sampling to generate token samples
+         3. Estimating the output distribution from the samples
+         4. Comparing convergence to the target vs. random reference distributions
+
+         We expect that as the sample size increases, the distance to the
+         target distribution decreases much more than the distance to the
+         random reference distributions.
+         """
+
+         vocab_size = 10
+         k = 2
+         num_reference_probs = 100
+
+         # Create random distributions
+         key = jax.random.PRNGKey(42)
+         draft_key, target_key, reference_key = jax.random.split(key, 3)
+
+         # Draft and target distributions
+         draft_logits = jax.random.normal(draft_key, (vocab_size, ))
+         draft_probs = jax.nn.softmax(draft_logits)
+
+         target_logits = jax.random.normal(target_key, (vocab_size, ))
+         target_probs = jax.nn.softmax(target_logits)
+
+         # Reference distributions for comparison
+         reference_logits = jax.random.normal(reference_key,
+                                              (num_reference_probs, vocab_size))
+         reference_probs = jax.nn.softmax(reference_logits, axis=-1)
+
+         sample_sizes = [10, 100, 1_000, 10_000, 100_000]
+         distance_wrt_reference: List[float] = []
+         distance_wrt_target: List[float] = []
+
+         for num_samples in sample_sizes:
+             # Estimate the rejection sampling distribution
+             estimated_probs = self._estimate_rejection_sampling_pdf(
+                 draft_probs, target_logits, k, vocab_size, num_samples)
+
+             # Calculate distances
+             reference_vs_rejsample_dist = float(
+                 jnp.mean(
+                     jnp.linalg.norm(reference_probs - estimated_probs[None, :],
+                                     axis=-1)))
+             target_vs_rejsample_dist = float(
+                 jnp.linalg.norm(target_probs - estimated_probs))
+
+             distance_wrt_reference.append(reference_vs_rejsample_dist)
+             distance_wrt_target.append(target_vs_rejsample_dist)
+
+             print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
+                   f"{reference_vs_rejsample_dist=:.05f}")
+
+         # Calculate relative improvements
+         relative_change_target = self._get_ratio_first_to_last(
+             distance_wrt_target)
+         relative_change_reference = self._get_ratio_first_to_last(
+             distance_wrt_reference)
+
+         print(f"Target improvement ratio: {relative_change_target:.02f}")
+         print(f"Reference improvement ratio: {relative_change_reference:.02f}")
+
+         # Validation: the distance to the target should shrink far more than
+         # the distance to the random reference distributions
+         expected_improvement_multiplier = 20
+         assert (relative_change_target >
+                 relative_change_reference * expected_improvement_multiplier), \
+             f"Target convergence ({relative_change_target:.2f}) should be " \
+             f"{expected_improvement_multiplier}x better than reference " \
+             f"({relative_change_reference:.2f})"
+
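The property being measured here is the standard guarantee from the speculative-decoding literature: accept a draft token x ~ q with probability min(1, p(x)/q(x)), and otherwise resample from the normalized residual max(0, p - q), which makes the output marginal exactly p. A minimal single-token sketch of that textbook rule (our reference implementation, not the TPU kernel):

```python
import jax
import jax.numpy as jnp

def accept_or_resample(key, draft_token, draft_probs, target_probs):
    """One step of the textbook rejection rule. The returned token is
    distributed exactly according to target_probs."""
    u_key, r_key = jax.random.split(key)
    ratio = target_probs[draft_token] / draft_probs[draft_token]
    accept = jax.random.uniform(u_key) < jnp.minimum(1.0, ratio)
    # Residual distribution used on rejection: max(0, p - q), renormalized.
    residual = jnp.maximum(target_probs - draft_probs, 0.0)
    residual = residual / jnp.maximum(residual.sum(), 1e-9)
    resampled = jax.random.categorical(r_key, jnp.log(residual + 1e-9))
    return jnp.where(accept, draft_token, resampled)
```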
+     def _estimate_rejection_sampling_pdf(
+         self,
+         draft_probs: jnp.ndarray,
+         target_logits: jnp.ndarray,
+         k: int,
+         vocab_size: int,
+         num_samples: int,
+     ) -> jnp.ndarray:
+         """Estimate the probability distribution of rejection sampling output.
+
+         Args:
+             draft_probs: Draft probability distribution [vocab_size]
+             target_logits: Target logits [vocab_size]
+             k: Number of draft tokens per sequence
+             vocab_size: Size of the vocabulary
+             num_samples: Number of samples to generate
+
+         Returns:
+             Estimated probability distribution [vocab_size]
+         """
+         rejection_sampler = RejectionSampler()
+
+         # Prepare inputs in the flattened format expected by the TPU sampler
+         num_tokens = num_samples * k
+
+         # Expand draft probs to the flattened format [num_tokens, vocab_size]
+         draft_probs_expanded = jnp.tile(draft_probs[None, :], (num_tokens, 1))
+
+         # Expand target logits to the flattened format
+         target_logits_expanded = jnp.tile(target_logits[None, :],
+                                           (num_tokens, 1))
+
+         # Generate random draft token ids from the draft distribution
+         key = jax.random.PRNGKey(123)
+         draft_tokens_2d = jax.random.categorical(key,
+                                                  jnp.log(draft_probs + 1e-8),
+                                                  shape=(num_samples, k))
+         draft_token_ids = draft_tokens_2d.flatten()
+
+         # Prepare other inputs
+         num_draft_tokens = jnp.full((num_samples, ), k, dtype=jnp.int32)
+         bonus_token_ids = jnp.zeros((num_samples, ),
+                                     dtype=jnp.int32)  # Not used in estimation
+
+         # Create sampling metadata for non-greedy sampling
+         sampling_metadata = TPUSupportedSamplingMetadata(
+             do_sampling=True,  # Non-greedy sampling
+             logprobs=False,
+             top_k=jnp.full((num_samples, ), -1, dtype=jnp.int32),
+             top_p=jnp.full((num_samples, ), 1.0, dtype=jnp.float32),
+             temperature=jnp.full((num_samples, ), 1.0, dtype=jnp.float32),
+         )
+
+         # Run rejection sampling
+         sample_key = jax.random.PRNGKey(456)
+         output_token_ids = rejection_sampler(
+             draft_token_ids=draft_token_ids,
+             num_draft_tokens=num_draft_tokens,
+             draft_probs=draft_probs_expanded,
+             target_logits=target_logits_expanded,
+             bonus_token_ids=bonus_token_ids,
+             sampling_metadata=sampling_metadata,
+             key=sample_key,
+         )
+
+         # Parse the output and extract the main tokens (excluding bonus tokens)
+         parsed_output = rejection_sampler.parse_output(
+             output_token_ids,
+             vocab_size=vocab_size,
+             num_draft_tokens_cpu=np.asarray(num_draft_tokens),
+             batch_size=num_samples,
+             padded_tokens_length=num_tokens)
+
+         # Flatten all main tokens. The bonus token is appended only when every
+         # draft token was accepted, so a sequence longer than k includes one;
+         # otherwise all returned tokens are valid up to the rejection point.
+         all_tokens = []
+         for seq_tokens in parsed_output:
+             if len(seq_tokens) == 0:
+                 continue
+             if len(seq_tokens) > k:
+                 # More tokens than draft tokens means a bonus token is included
+                 main_tokens = seq_tokens[:k]
+             else:
+                 # No bonus token, take all tokens
+                 main_tokens = seq_tokens
+             all_tokens.extend(main_tokens)
+
+         if not all_tokens:
+             # Fallback to a uniform distribution if no tokens were generated
+             return jnp.ones(vocab_size) / vocab_size
+
+         # Convert to numpy for histogram computation
+         tokens_array = np.array(all_tokens, dtype=np.int32)
+
+         # Estimate the probability distribution with a histogram
+         hist, _ = np.histogram(tokens_array,
+                                bins=vocab_size,
+                                range=(0, vocab_size),
+                                density=True)
+
+         # Normalize to ensure it sums to 1
+         hist = hist / (hist.sum() + 1e-8)
+
+         return jnp.array(hist, dtype=jnp.float32)
+
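The estimator above boils down to a normalized histogram. The toy example below (numbers ours) shows the same Monte-Carlo convergence the test asserts, with the L2 distance shrinking roughly like 1/sqrt(n):

```python
import numpy as np

rng = np.random.default_rng(0)
true_p = np.array([0.5, 0.3, 0.2])
for n in (100, 10_000, 1_000_000):
    samples = rng.choice(3, size=n, p=true_p)
    hist, _ = np.histogram(samples, bins=3, range=(0, 3), density=True)
    hist = hist / hist.sum()  # guard against floating-point drift
    print(n, np.linalg.norm(true_p - hist))  # distance shrinks with n
```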
+     def _get_ratio_first_to_last(self, elements: List[float]) -> float:
+         """Calculate ratio of first to last element in list."""
+         if len(elements) < 2 or elements[-1] == 0:
+             return 1.0
+         return elements[0] / elements[-1]
+
+
+ # ======================== TOP-K AND TOP-P SAMPLING TESTS ========================
+
+
+ class TestTopKTopPSampling:
+     """Test suite for top-k and top-p sampling with the rejection sampler."""
+
+     def _test_masked_logits(
+         self,
+         rejection_sampler: RejectionSampler,
+         batch_size: int,
+         num_draft_tokens: int,
+         vocab_size: int,
+         target_logits: jnp.ndarray,
+         allowed_tokens_per_pos: List[jnp.ndarray],
+         sampling_metadata: TPUSupportedSamplingMetadata,
+     ):
+         """Helper that checks only allowed tokens are ever sampled.
+
+         Args:
+             rejection_sampler: The rejection sampler instance
+             batch_size: Number of sequences in the batch
+             num_draft_tokens: Number of draft tokens per sequence
+             vocab_size: Size of the vocabulary
+             target_logits: Target logits tensor
+             allowed_tokens_per_pos: List of allowed token arrays for each position
+             sampling_metadata: Sampling metadata with top-k/top-p settings
+         """
+         num_tokens = batch_size * num_draft_tokens
+
+         # Create random draft probabilities
+         key = jax.random.PRNGKey(42)
+         draft_logits = jax.random.normal(key, (num_tokens, vocab_size))
+         draft_probs = jax.nn.softmax(draft_logits, axis=-1)
+
+         # Randomly sample draft token ids from the draft probs
+         draft_key = jax.random.PRNGKey(123)
+         draft_token_ids = jax.random.categorical(draft_key,
+                                                  jnp.log(draft_probs + 1e-8),
+                                                  shape=(num_tokens, ))
+
+         # Prepare inputs
+         num_draft_per_seq = jnp.full((batch_size, ),
+                                      num_draft_tokens,
+                                      dtype=jnp.int32)
+         bonus_token_ids = jnp.zeros((batch_size, ), dtype=jnp.int32)
+
+         # Run rejection sampling multiple times for statistical confidence
+         sample_keys = jax.random.split(jax.random.PRNGKey(456), 10)
+         all_sampled_tokens = []
+
+         for sample_key in sample_keys:
+             output_token_ids = rejection_sampler(
+                 draft_token_ids=draft_token_ids,
+                 num_draft_tokens=num_draft_per_seq,
+                 draft_probs=draft_probs,
+                 target_logits=target_logits,
+                 bonus_token_ids=bonus_token_ids,
+                 sampling_metadata=sampling_metadata,
+                 key=sample_key,
+             )
+
+             # Parse the output and extract tokens
+             parsed_output = rejection_sampler.parse_output(
+                 output_token_ids,
+                 vocab_size=vocab_size,
+                 num_draft_tokens_cpu=np.asarray(num_draft_per_seq),
+                 batch_size=batch_size,
+                 padded_tokens_length=num_tokens)
+
+             # For each sequence, check draft tokens (excluding bonus tokens)
+             for seq_idx, seq_tokens in enumerate(parsed_output):
+                 for pos, token_id in enumerate(seq_tokens):
+                     if pos < num_draft_tokens:  # Only check draft tokens, not bonus
+                         token_idx = seq_idx * num_draft_tokens + pos
+                         if token_idx < len(allowed_tokens_per_pos):
+                             allowed_tokens = allowed_tokens_per_pos[token_idx]
+                             all_sampled_tokens.append(
+                                 (token_idx, token_id, allowed_tokens))
+
+         # Check that all sampled tokens are within the allowed sets
+         for token_idx, token_id, allowed_tokens in all_sampled_tokens:
+             assert token_id in allowed_tokens, \
+                 f"Token {token_id} at position {token_idx} not in allowed set {allowed_tokens.tolist()}"
+
+     @pytest.mark.parametrize("top_k", [1, 5, 99])
+     def test_top_k(self, rejection_sampler, test_helper, top_k):
+         """Test rejection sampling with top-k sampling."""
+         vocab_size = 100
+         batch_size = 10
+         num_draft_tokens = 3
+         num_tokens = batch_size * num_draft_tokens
+
+         # Randomly create top-k indices for each token position
+         key = jax.random.PRNGKey(42)
+         top_k_indices = []
+         for i in range(num_tokens):
+             perm_key = jax.random.fold_in(key, i)
+             indices = jax.random.permutation(perm_key, vocab_size)[:top_k]
+             top_k_indices.append(indices)
+
+         # Create target logits with a uniform distribution
+         target_logits = jnp.zeros((num_tokens, vocab_size), dtype=jnp.float32)
+
+         # Increment the logits at the top-k indices slightly to make them more
+         # likely. If masking works correctly, only these tokens are sampled.
+         for i in range(num_tokens):
+             indices = top_k_indices[i]
+             target_logits = target_logits.at[i, indices].add(0.1)
+
+         # Create sampling metadata with top-k
+         sampling_metadata = test_helper.create_sampling_metadata(
+             all_greedy=False,
+             batch_size=batch_size,
+             top_k=top_k,
+             top_p=1.0,
+             temperature=1.0,
+         )
+
+         self._test_masked_logits(
+             rejection_sampler=rejection_sampler,
+             batch_size=batch_size,
+             num_draft_tokens=num_draft_tokens,
+             vocab_size=vocab_size,
+             target_logits=target_logits,
+             allowed_tokens_per_pos=top_k_indices,
+             sampling_metadata=sampling_metadata,
+         )
+
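For orientation, a minimal sketch of the top-k masking this test expects to see honored. This is a common formulation (ours); the sampler's actual kernel may implement it differently:

```python
import jax.numpy as jnp

def mask_top_k(logits: jnp.ndarray, k: int) -> jnp.ndarray:
    """Keep the k largest logits per row; send the rest to -inf so they
    get zero probability under softmax. Ties at the threshold may keep
    a few extra tokens."""
    kth_largest = jnp.sort(logits, axis=-1)[..., -k][..., None]
    return jnp.where(logits >= kth_largest, logits, -jnp.inf)

row = jnp.array([[0.1, 2.0, 1.5, -0.3]])
print(mask_top_k(row, 2))  # only the 2.0 and 1.5 entries survive
```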
+     @pytest.mark.parametrize("top_p", [0.5, 0.9, 0.99])
+     def test_top_p(self, rejection_sampler, test_helper, top_p):
+         """Test rejection sampling with top-p (nucleus) sampling."""
+         vocab_size = 100
+         batch_size = 10
+         num_draft_tokens = 3
+         num_tokens = batch_size * num_draft_tokens
+
+         # Create random target logits
+         key = jax.random.PRNGKey(42)
+         target_logits = jax.random.normal(key, (num_tokens, vocab_size))
+
+         # Create a temperature array for the batch
+         temperature = jnp.ones(batch_size, dtype=jnp.float32)
+
+         # Calculate the expected top-p indices for each token position
+         rescaled_logits = target_logits / temperature.repeat(num_draft_tokens,
+                                                              axis=0)[:, None]
+
+         # Sort logits ascending and calculate cumulative probabilities
+         logits_sorted = jnp.sort(rescaled_logits, axis=-1)
+         logits_idx = jnp.argsort(rescaled_logits, axis=-1)
+         probs_sorted = jax.nn.softmax(logits_sorted, axis=-1)
+         probs_cumsum = jnp.cumsum(probs_sorted, axis=-1)
+
+         # Mask the low-probability tail: with an ascending sort, dropping the
+         # prefix whose cumulative mass is <= 1 - top_p keeps the smallest set
+         # of high-probability tokens with mass >= top_p.
+         top_p_mask = probs_cumsum <= (1 - top_p)
+         # Ensure at least one token is kept
+         top_p_mask = top_p_mask.at[:, -1].set(False)
+
+         # Get the top-p indices for each position
+         top_p_indices = []
+         for i in range(num_tokens):
+             valid_indices = logits_idx[i][~top_p_mask[i]]
+             top_p_indices.append(valid_indices)
+
+         # Create sampling metadata with top-p
+         sampling_metadata = test_helper.create_sampling_metadata(
+             all_greedy=False,
+             batch_size=batch_size,
+             top_k=-1,
+             top_p=top_p,
+             temperature=1.0,
+         )
+
+         self._test_masked_logits(
+             rejection_sampler=rejection_sampler,
+             batch_size=batch_size,
+             num_draft_tokens=num_draft_tokens,
+             vocab_size=vocab_size,
+             target_logits=target_logits,
+             allowed_tokens_per_pos=top_p_indices,
+             sampling_metadata=sampling_metadata,
+         )
+
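A self-contained check (toy numbers ours) that the ascending-sort identity used in the test agrees with the conventional descending-sort definition of the nucleus:

```python
import jax
import jax.numpy as jnp

logits = jnp.log(jnp.array([0.5, 0.25, 0.15, 0.1]))
top_p = 0.8

# Ascending (as in the test): drop the tail whose cumulative mass <= 1 - p.
asc = jnp.argsort(logits)
drop = jnp.cumsum(jax.nn.softmax(logits[asc])) <= (1 - top_p)
keep_asc = {int(t) for t in asc[~drop]}

# Descending (textbook): keep tokens until cumulative mass reaches p.
desc = jnp.argsort(-logits)
probs_desc = jax.nn.softmax(logits[desc])
keep_desc = {int(t) for t in desc[jnp.cumsum(probs_desc) - probs_desc < top_p]}

assert keep_asc == keep_desc == {0, 1, 2}
```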
+     def test_top_k_and_top_p_combined(self, rejection_sampler, test_helper):
+         """Test rejection sampling with both top-k and top-p applied.
+
+         This test verifies that top-k and top-p can be used together
+         without errors. It does not verify the exact masking behavior,
+         since the sampler's order of application may differ from our
+         test implementation.
+         """
+         vocab_size = 50
+         batch_size = 5
+         num_draft_tokens = 2
+         num_tokens = batch_size * num_draft_tokens
+         top_k = 10
+         top_p = 0.8
+
+         # Create random target logits
+         key = jax.random.PRNGKey(123)
+         target_logits = jax.random.normal(key, (num_tokens, vocab_size))
+
+         # Create random draft probabilities
+         draft_key = jax.random.PRNGKey(42)
+         draft_logits = jax.random.normal(draft_key, (num_tokens, vocab_size))
+         draft_probs = jax.nn.softmax(draft_logits, axis=-1)
+
+         # Randomly sample draft token ids from the draft probs
+         sample_key = jax.random.PRNGKey(123)
+         draft_token_ids = jax.random.categorical(sample_key,
+                                                  jnp.log(draft_probs + 1e-8),
+                                                  shape=(num_tokens, ))
+
+         # Create sampling metadata with both top-k and top-p
+         sampling_metadata = test_helper.create_sampling_metadata(
+             all_greedy=False,
+             batch_size=batch_size,
+             top_k=top_k,
+             top_p=top_p,
+             temperature=1.0,
+         )
+
+         # Prepare inputs
+         num_draft_per_seq = jnp.full((batch_size, ),
+                                      num_draft_tokens,
+                                      dtype=jnp.int32)
+         bonus_token_ids = jnp.zeros((batch_size, ), dtype=jnp.int32)
+
+         # Test only that the combined sampling runs without errors
+         run_key = jax.random.PRNGKey(456)
+         output_token_ids = rejection_sampler(
+             draft_token_ids=draft_token_ids,
+             num_draft_tokens=num_draft_per_seq,
+             draft_probs=draft_probs,
+             target_logits=target_logits,
+             bonus_token_ids=bonus_token_ids,
+             sampling_metadata=sampling_metadata,
+             key=run_key,
+         )
+
+         # Parse the output to verify it is well-formed
+         parsed_output = rejection_sampler.parse_output(
+             output_token_ids,
+             vocab_size=vocab_size,
+             num_draft_tokens_cpu=np.asarray(num_draft_per_seq),
+             batch_size=batch_size,
+             padded_tokens_length=num_tokens)
+
+         # Basic sanity checks
+         assert len(parsed_output) == batch_size
+         for seq_tokens in parsed_output:
+             assert len(seq_tokens) >= 0  # A sequence may legitimately be empty
+             for token_id in seq_tokens:
+                 assert 0 <= token_id < vocab_size  # Valid token range
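Finally, one assumed order of composing the two filters (top-k first, then top-p on the survivors). As the docstring above notes, the sampler's actual order may differ, so this is only an illustration:

```python
import jax
import jax.numpy as jnp

def sample_top_k_top_p(key, logits, k, p):
    """Apply top-k, then top-p, then sample (assumed order of application)."""
    # Top-k: keep the k largest logits.
    kth = jnp.sort(logits)[-k]
    logits = jnp.where(logits >= kth, logits, -jnp.inf)
    # Top-p: drop the low-probability tail of the survivors.
    order = jnp.argsort(logits)  # ascending
    probs = jax.nn.softmax(logits[order])
    drop = jnp.cumsum(probs) <= (1.0 - p)
    drop = drop.at[-1].set(False)  # always keep the argmax
    logits = logits.at[order].set(jnp.where(drop, -jnp.inf, logits[order]))
    return jax.random.categorical(key, logits)

key = jax.random.PRNGKey(0)
logit_key, sample_key = jax.random.split(key)
token = sample_top_k_top_p(sample_key, jax.random.normal(logit_key, (50, )),
                           k=10, p=0.8)
assert 0 <= int(token) < 50
```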