tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.

Files changed (248)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +14 -0
  31. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  32. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
  35. tests/layers/__init__.py +13 -0
  36. tests/layers/common/__init__.py +13 -0
  37. tests/layers/common/test_attention_interface.py +156 -0
  38. tests/layers/common/test_quantization.py +149 -0
  39. tests/layers/jax/__init__.py +13 -0
  40. tests/layers/jax/attention/__init__.py +13 -0
  41. tests/layers/jax/attention/test_common_attention.py +103 -0
  42. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  43. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  44. tests/layers/jax/moe/__init__.py +13 -0
  45. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  46. tests/layers/jax/sample/__init__.py +13 -0
  47. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  48. tests/layers/jax/sample/test_sampling.py +115 -0
  49. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  50. tests/layers/jax/test_layers.py +155 -0
  51. tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
  52. tests/layers/jax/test_rope.py +93 -0
  53. tests/layers/jax/test_sharding.py +159 -0
  54. tests/layers/jax/test_transformer_block.py +152 -0
  55. tests/layers/vllm/__init__.py +13 -0
  56. tests/layers/vllm/test_attention.py +363 -0
  57. tests/layers/vllm/test_awq.py +406 -0
  58. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  59. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  61. tests/layers/vllm/test_fp8.py +17 -0
  62. tests/layers/vllm/test_mxfp4.py +320 -0
  63. tests/layers/vllm/test_unquantized.py +662 -0
  64. tests/layers/vllm/utils.py +87 -0
  65. tests/lora/__init__.py +13 -0
  66. tests/lora/conftest.py +14 -0
  67. tests/lora/test_bgmv.py +14 -0
  68. tests/lora/test_layers.py +25 -8
  69. tests/lora/test_lora.py +15 -1
  70. tests/lora/test_lora_perf.py +14 -0
  71. tests/models/__init__.py +13 -0
  72. tests/models/common/__init__.py +13 -0
  73. tests/models/common/test_model_loader.py +455 -0
  74. tests/models/jax/__init__.py +13 -0
  75. tests/models/jax/test_deepseek_v3.py +401 -0
  76. tests/models/jax/test_llama3.py +184 -0
  77. tests/models/jax/test_llama4.py +298 -0
  78. tests/models/jax/test_llama_eagle3.py +197 -0
  79. tests/models/jax/test_llama_guard_4.py +242 -0
  80. tests/models/jax/test_qwen2.py +172 -0
  81. tests/models/jax/test_qwen2_5_vl.py +605 -0
  82. tests/models/jax/test_qwen3.py +169 -0
  83. tests/models/jax/test_weight_loading.py +180 -0
  84. tests/models/jax/utils/__init__.py +13 -0
  85. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  86. tests/platforms/__init__.py +13 -0
  87. tests/platforms/test_tpu_platform.py +54 -0
  88. tests/runner/__init__.py +13 -0
  89. tests/runner/test_block_table.py +395 -0
  90. tests/runner/test_input_batch.py +226 -0
  91. tests/runner/test_kv_cache.py +220 -0
  92. tests/runner/test_kv_cache_manager.py +498 -0
  93. tests/runner/test_multimodal_manager.py +429 -0
  94. tests/runner/test_persistent_batch_manager.py +84 -0
  95. tests/runner/test_speculative_decoding_manager.py +368 -0
  96. tests/runner/test_structured_decoding_manager.py +220 -0
  97. tests/runner/test_tpu_runner.py +261 -0
  98. tests/runner/test_tpu_runner_dp.py +1099 -0
  99. tests/runner/test_tpu_runner_mesh.py +200 -0
  100. tests/runner/test_utils.py +411 -0
  101. tests/spec_decode/__init__.py +13 -0
  102. tests/spec_decode/test_eagle3.py +311 -0
  103. tests/test_base.py +14 -0
  104. tests/test_tpu_info.py +14 -0
  105. tests/test_utils.py +1 -43
  106. tests/worker/__init__.py +13 -0
  107. tests/worker/tpu_worker_test.py +414 -0
  108. tpu_inference/__init__.py +14 -0
  109. tpu_inference/core/__init__.py +13 -0
  110. tpu_inference/core/sched/__init__.py +13 -0
  111. tpu_inference/core/sched/dp_scheduler.py +372 -56
  112. tpu_inference/distributed/__init__.py +13 -0
  113. tpu_inference/distributed/jax_parallel_state.py +14 -0
  114. tpu_inference/distributed/tpu_connector.py +14 -9
  115. tpu_inference/distributed/utils.py +56 -4
  116. tpu_inference/executors/__init__.py +13 -0
  117. tpu_inference/executors/ray_distributed_executor.py +20 -3
  118. tpu_inference/experimental/__init__.py +13 -0
  119. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  120. tpu_inference/kernels/__init__.py +13 -0
  121. tpu_inference/kernels/collectives/__init__.py +13 -0
  122. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  123. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  124. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  125. tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
  126. tpu_inference/kernels/megablox/__init__.py +13 -0
  127. tpu_inference/kernels/megablox/common.py +54 -0
  128. tpu_inference/kernels/megablox/gmm.py +646 -0
  129. tpu_inference/kernels/mla/__init__.py +13 -0
  130. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  131. tpu_inference/kernels/mla/v1/kernel.py +20 -26
  132. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  133. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  134. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  135. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  136. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
  137. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
  138. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  139. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
  140. tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
  141. tpu_inference/layers/__init__.py +13 -0
  142. tpu_inference/layers/common/__init__.py +13 -0
  143. tpu_inference/layers/common/attention_interface.py +26 -19
  144. tpu_inference/layers/common/attention_metadata.py +14 -0
  145. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  146. tpu_inference/layers/common/quant_methods.py +15 -0
  147. tpu_inference/layers/common/quantization.py +282 -0
  148. tpu_inference/layers/common/sharding.py +22 -3
  149. tpu_inference/layers/common/utils.py +94 -0
  150. tpu_inference/layers/jax/__init__.py +13 -0
  151. tpu_inference/layers/jax/attention/__init__.py +13 -0
  152. tpu_inference/layers/jax/attention/attention.py +19 -6
  153. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
  154. tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
  155. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  156. tpu_inference/layers/jax/base.py +14 -0
  157. tpu_inference/layers/jax/constants.py +13 -0
  158. tpu_inference/layers/jax/layers.py +14 -0
  159. tpu_inference/layers/jax/misc.py +14 -0
  160. tpu_inference/layers/jax/moe/__init__.py +13 -0
  161. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  162. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  163. tpu_inference/layers/jax/moe/moe.py +43 -3
  164. tpu_inference/layers/jax/pp_utils.py +53 -0
  165. tpu_inference/layers/jax/rope.py +14 -0
  166. tpu_inference/layers/jax/rope_interface.py +14 -0
  167. tpu_inference/layers/jax/sample/__init__.py +13 -0
  168. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  169. tpu_inference/layers/jax/sample/sampling.py +15 -1
  170. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  171. tpu_inference/layers/jax/transformer_block.py +14 -0
  172. tpu_inference/layers/vllm/__init__.py +13 -0
  173. tpu_inference/layers/vllm/attention.py +4 -4
  174. tpu_inference/layers/vllm/fused_moe.py +100 -455
  175. tpu_inference/layers/vllm/linear.py +64 -0
  176. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  177. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  178. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  179. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  180. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  181. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  182. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  183. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
  184. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  188. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
  189. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  190. tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
  191. tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
  192. tpu_inference/lora/__init__.py +13 -0
  193. tpu_inference/lora/torch_lora_ops.py +8 -13
  194. tpu_inference/models/__init__.py +13 -0
  195. tpu_inference/models/common/__init__.py +13 -0
  196. tpu_inference/models/common/model_loader.py +37 -16
  197. tpu_inference/models/jax/__init__.py +13 -0
  198. tpu_inference/models/jax/deepseek_v3.py +113 -124
  199. tpu_inference/models/jax/gpt_oss.py +23 -7
  200. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  201. tpu_inference/models/jax/llama3.py +99 -36
  202. tpu_inference/models/jax/llama4.py +14 -0
  203. tpu_inference/models/jax/llama_eagle3.py +14 -0
  204. tpu_inference/models/jax/llama_guard_4.py +15 -1
  205. tpu_inference/models/jax/qwen2.py +17 -2
  206. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  207. tpu_inference/models/jax/qwen3.py +17 -2
  208. tpu_inference/models/jax/utils/__init__.py +13 -0
  209. tpu_inference/models/jax/utils/file_utils.py +14 -0
  210. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  211. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
  213. tpu_inference/models/jax/utils/weight_utils.py +32 -1
  214. tpu_inference/models/vllm/__init__.py +13 -0
  215. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
  216. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  217. tpu_inference/platforms/__init__.py +14 -0
  218. tpu_inference/platforms/tpu_platform.py +27 -29
  219. tpu_inference/runner/__init__.py +13 -0
  220. tpu_inference/runner/compilation_manager.py +69 -35
  221. tpu_inference/runner/kv_cache.py +14 -0
  222. tpu_inference/runner/kv_cache_manager.py +15 -2
  223. tpu_inference/runner/lora_utils.py +16 -1
  224. tpu_inference/runner/multimodal_manager.py +16 -2
  225. tpu_inference/runner/persistent_batch_manager.py +14 -0
  226. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  227. tpu_inference/runner/structured_decoding_manager.py +14 -0
  228. tpu_inference/runner/tpu_runner.py +30 -10
  229. tpu_inference/spec_decode/__init__.py +13 -0
  230. tpu_inference/spec_decode/jax/__init__.py +13 -0
  231. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  232. tpu_inference/tpu_info.py +14 -0
  233. tpu_inference/utils.py +31 -30
  234. tpu_inference/worker/__init__.py +13 -0
  235. tpu_inference/worker/tpu_worker.py +23 -7
  236. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
  237. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  238. tpu_inference/layers/vllm/linear_common.py +0 -208
  239. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  240. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  241. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  242. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  245. tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
  246. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  247. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  248. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tests/runner/test_speculative_decoding_manager.py
@@ -0,0 +1,368 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock, patch
+
+import jax
+import numpy as np
+import pytest
+from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
+                         SchedulerConfig, SpeculativeConfig, VllmConfig)
+from vllm.sampling_params import SamplingType
+from vllm.v1.outputs import DraftTokenIds
+
+from tpu_inference.runner.input_batch import CachedRequestState, InputBatch
+from tpu_inference.runner.speculative_decoding_manager import \
+    SpecDecodeMetadata
+from tpu_inference.runner.tpu_runner import TPUModelRunner
+from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
+
+
+class TestSpeculativeDecodingManager:
+
+    def setup_method(self):
+        # Mock JAX dependencies
+        self.mock_devices = [MagicMock(coords=i) for i in range(1)]
+        device_array = np.array(jax.devices()[:1]).reshape(1, 1, 1, 1)
+        self.mock_mesh = jax.make_mesh(device_array.shape,
+                                       ('data', 'attn_dp', 'expert', 'model'))
+        self.mock_rng_key = MagicMock()
+
+        with patch('jax.devices', return_value=self.mock_devices), \
+             patch('jax.make_mesh', return_value=self.mock_mesh), \
+             patch('jax.random.key', return_value=self.mock_rng_key), \
+             patch('tpu_inference.runner.tpu_runner.get_model', return_value=MagicMock()), \
+             patch('tpu_inference.runner.tpu_runner.make_optimized_mesh', return_value=self.mock_mesh):
+
+            model_config = ModelConfig(tokenizer_mode="auto",
+                                       trust_remote_code=False,
+                                       seed=0,
+                                       dtype='bfloat16')
+            cache_config = CacheConfig(
+                block_size=16,
+                gpu_memory_utilization=0.9,
+                swap_space=4,
+                cache_dtype="auto",
+            )
+            scheduler_config = SchedulerConfig(max_num_seqs=16,
+                                               max_model_len=1024,
+                                               is_encoder_decoder=False)
+            parallel_config = ParallelConfig(
+                pipeline_parallel_size=1,
+                tensor_parallel_size=1,
+                worker_use_ray=False,
+            )
+            speculative_config = SpeculativeConfig(
+                model='ngram',
+                num_speculative_tokens=5,
+                prompt_lookup_max=4,
+            )
+            vllm_config = VllmConfig(
+                model_config=model_config,
+                cache_config=cache_config,
+                scheduler_config=scheduler_config,
+                parallel_config=parallel_config,
+                speculative_config=speculative_config,
+                observability_config={},
+                additional_config={},
+            )
+
+            self.runner = TPUModelRunner(vllm_config,
+                                         devices=self.mock_devices)
+
+    def test_propose_draft_token_ids_dispatches_to_eagle(self):
+        """Tests that propose_draft_token_ids calls the correct eagle method."""
+        # 1. ===== Setup =====
+        # Set the drafter to be an Eagle3Proposer
+        self.runner.drafter = MagicMock(spec=Eagle3Proposer)
+        self.runner.speculative_config.method = "eagle3"
+
+        # Mock the eagle-specific proposal method
+        with patch.object(self.runner.speculative_decoding_manager,
+                          'propose_eagle3_draft_token_ids',
+                          return_value=[[10, 11]]) as mock_propose_eagle:
+
+            # 2. ===== Act =====
+            self.runner.speculative_decoding_manager.propose_draft_token_ids(
+                sampled_token_ids=[[1]],
+                aux_hidden_states=None,
+                attn_metadata=MagicMock(),
+                spec_decode_metadata=None,
+            )
+
+            # 3. ===== Assert =====
+            mock_propose_eagle.assert_called_once()
+            assert self.runner.speculative_decoding_manager._draft_token_ids == [
+                [10, 11]
+            ]
+
+    def test_propose_draft_token_ids_wrong_drafter_type(self):
+        """Tests that an assertion is raised if the drafter is not an NgramProposer."""
+        # The default drafter is NgramProposer, so we replace it with a generic mock
+        self.runner.drafter = MagicMock()
+        self.runner.speculative_config.method = "ngram"
+        with pytest.raises(AssertionError):
+            self.runner.speculative_decoding_manager.propose_draft_token_ids(
+                [[1]], None, MagicMock(), None)
+
+    def test_take_draft_token_ids(self):
+        """Tests the take_draft_token_ids method for speculative decoding."""
+        # Case 1: No draft tokens are available.
+        self.runner.speculative_decoding_manager._draft_token_ids = None
+        result = self.runner.take_draft_token_ids()
+        assert result is None
+
+        # Case 2: Draft tokens are available.
+        mock_req_ids = ["req-1", "req-2"]
+        mock_draft_ids = [[10, 11], [20, 21, 22]]
+
+        # Re-initialize input_batch for a clean state for this specific test
+        self.runner.input_batch = InputBatch(
+            max_num_reqs=self.runner.max_num_reqs,
+            max_model_len=self.runner.max_model_len,
+            max_num_batched_tokens=self.runner.max_num_tokens,
+            pin_memory=False,
+            vocab_size=self.runner.vocab_size,
+            block_sizes=[self.runner.block_size],
+            is_spec_decode=True,
+        )
+
+        # Add some requests to populate `input_batch.req_ids`
+        mock_sampling_params = MagicMock()
+        mock_sampling_params.sampling_type = SamplingType.GREEDY
+        mock_sampling_params.top_k = -1
+        mock_sampling_params.top_p = 1.0
+        mock_sampling_params.temperature = 0.0
+        mock_sampling_params.min_tokens = 0
+        mock_sampling_params.logprobs = None
+        mock_sampling_params.logit_bias = None
+        mock_sampling_params.allowed_token_ids = set()
+        mock_sampling_params.bad_words_token_ids = None
+        mock_sampling_params.all_stop_token_ids = set()
+
+        req1 = CachedRequestState(req_id="req-1",
+                                  prompt_token_ids=[1],
+                                  output_token_ids=[],
+                                  sampling_params=mock_sampling_params,
+                                  block_ids=([1], ),
+                                  num_computed_tokens=1,
+                                  lora_request=None,
+                                  mm_features=[],
+                                  pooling_params=None,
+                                  generator=None)
+        req2 = CachedRequestState(req_id="req-2",
+                                  prompt_token_ids=[2],
+                                  output_token_ids=[],
+                                  sampling_params=mock_sampling_params,
+                                  block_ids=([2], ),
+                                  num_computed_tokens=1,
+                                  lora_request=None,
+                                  mm_features=[],
+                                  pooling_params=None,
+                                  generator=None)
+        self.runner.input_batch.add_request(req1)
+        self.runner.input_batch.add_request(req2)
+
+        # Set the draft tokens to be taken
+        self.runner.speculative_decoding_manager._draft_token_ids = mock_draft_ids
+
+        # Call the method to be tested
+        result = self.runner.take_draft_token_ids()
+
+        # Assertions for the returned object
+        assert result is not None
+        assert isinstance(result, DraftTokenIds)
+        assert result.req_ids == mock_req_ids
+        assert result.draft_token_ids == mock_draft_ids
+
+        # Assert that the internal state is reset
+        assert self.runner.speculative_decoding_manager._draft_token_ids is None
+
+        # Case 3: Call again after taking, should return None
+        result_after = self.runner.take_draft_token_ids()
+        assert result_after is None
+
+    def _setup_spec_decode_metadata_test(self):
+        """Helper method to set up common test infrastructure for spec decode metadata tests."""
+        # Mock runner attributes needed by the function
+        self.runner.arange_cpu = np.arange(1024, dtype=np.int64)
+        # Make input_ids_cpu a sequence of numbers for easy verification
+        self.runner.input_ids_cpu = np.arange(1024, dtype=np.int32) * 10
+        self.runner.num_tokens_paddings = [16, 32, 64, 128, 256, 512, 1024]
+
+        # Mock the device_array function to just return the numpy arrays
+        def mock_device_array(mesh, *args, **kwargs):
+            # Skip mesh parameter and return the actual arrays
+            if len(args) == 1 and isinstance(args[0], tuple):
+                return args[0]
+            return args
+
+        self.mock_device_array = mock_device_array
+
+    @pytest.mark.parametrize(
+        "num_draft_tokens,cu_num_scheduled_tokens,padded_num_reqs,expected_logits_indices,expected_bonus_logits_indices,expected_target_logits_indices,expected_draft_token_ids",
+        [
+            (
+                # Normal case
+                [3, 0, 2, 0, 1],
+                [4, 104, 107, 207, 209],
+                8,
+                [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208],
+                [3, 4, 7, 8, 10, 0, 0, 0],
+                [0, 1, 2, 5, 6, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+                [10, 20, 30, 1050, 1060, 2080, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
+            (
+                # High speculative tokens case
+                [5, 3, 4, 2, 1],
+                [6, 10, 18, 22, 26],
+                8,
+                [
+                    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 13, 14, 15, 16, 17, 19, 20,
+                    21, 24, 25
+                ],
+                [5, 9, 14, 17, 19, 0, 0, 0],
+                [
+                    0, 1, 2, 3, 4, 6, 7, 8, 10, 11, 12, 13, 15, 16, 18, 0, 0,
+                    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
+                ],
+                [
+                    10, 20, 30, 40, 50, 70, 80, 90, 140, 150, 160, 170, 200,
+                    210, 250
+                ]),
+        ])
+    def test_get_spec_decode_metadata_parametrized(
+            self, num_draft_tokens, cu_num_scheduled_tokens, padded_num_reqs,
+            expected_logits_indices, expected_bonus_logits_indices,
+            expected_target_logits_indices, expected_draft_token_ids):
+        """Comprehensive parametrized test for _get_spec_decode_metadata function."""
+        # Setup
+        self._setup_spec_decode_metadata_test()
+
+        # Convert Python lists to numpy arrays for function input
+        num_draft_tokens_np = np.array(num_draft_tokens, dtype=np.int32)
+        cu_num_scheduled_tokens_np = np.array(cu_num_scheduled_tokens,
+                                              dtype=np.int32)
+
+        # Act
+        with patch(
+                "tpu_inference.runner.speculative_decoding_manager.device_array",
+                side_effect=self.mock_device_array):
+            metadata = self.runner.speculative_decoding_manager.get_spec_decode_metadata(
+                num_draft_tokens_np,
+                cu_num_scheduled_tokens_np,
+                padded_num_reqs=padded_num_reqs)
+
+        # Assert basic properties
+        assert isinstance(metadata, SpecDecodeMetadata)
+
+        # Determine padding length based on expected_logits_indices length
+        if len(expected_logits_indices) <= 16:
+            padded_len = 16
+        else:
+            padded_len = 32
+
+        # final_logits_indices - pad to bucket size and compare as Python lists
+        expected_padded_logits_indices = expected_logits_indices + [0] * (
+            padded_len - len(expected_logits_indices))
+        assert np.asarray(metadata.final_logits_indices).tolist(
+        ) == expected_padded_logits_indices
+
+        # bonus_logits_indices - compare as Python lists
+        assert np.asarray(metadata.bonus_logits_indices).tolist(
+        ) == expected_bonus_logits_indices
+
+        # target_logits_indices - pad to same length as final_logits_indices and compare as Python lists
+        expected_padded_target_logits_indices = expected_target_logits_indices + [
+            0
+        ] * (padded_len - len(expected_target_logits_indices))
+        assert np.asarray(metadata.target_logits_indices).tolist(
+        ) == expected_padded_target_logits_indices
+
+        # draft_token_ids - pad the expected values to the correct length and compare as Python lists
+        expected_padded_draft_token_ids = expected_draft_token_ids + [0] * (
+            padded_len - len(expected_draft_token_ids))
+        assert np.asarray(metadata.draft_token_ids).tolist(
+        ) == expected_padded_draft_token_ids
+
+        # draft_lengths - pad and compare as Python lists
+        expected_padded_num_draft_tokens = num_draft_tokens + [0] * (
+            padded_num_reqs - len(num_draft_tokens))
+        assert np.asarray(metadata.draft_lengths).tolist(
+        ) == expected_padded_num_draft_tokens
+
+    @pytest.mark.parametrize("spec_decode_metadata_is_none", [True, False])
+    def test_propose_eagle3_draft_token_ids(self,
+                                            spec_decode_metadata_is_none):
+        """Tests the logic for proposing Eagle3 draft tokens."""
+        # 1. ===== Setup =====
+        self.runner.drafter = MagicMock(spec=Eagle3Proposer)
+        self.runner.speculative_config.method = "eagle3"
+
+        # Mock TPUModelRunner attributes
+        self.runner.input_batch = MagicMock()
+        self.runner.input_batch.req_ids = ["req-1", "req-2"]
+        self.runner.requests = {
+            "req-1": MagicMock(),
+            "req-2": MagicMock(),
+        }
+        self.runner.mesh = self.mock_mesh
+        self.runner.kv_caches = MagicMock()
+
+        # Mock drafter methods
+        mock_attn_metadata = MagicMock()
+        mock_target_token_ids = MagicMock()
+        mock_last_token_indices = MagicMock()
+        mock_target_hidden_states = MagicMock()
+        self.runner.drafter.prepare_inputs.return_value = (
+            mock_target_hidden_states,
+            mock_target_token_ids,
+            mock_last_token_indices,
+            mock_attn_metadata,
+        )
+        mock_draft_token_ids = [[10, 11], [20, 21]]
+        self.runner.drafter.propose.return_value = (
+            self.runner.kv_caches,
+            mock_draft_token_ids,
+        )
+
+        # Inputs
+        sampled_token_ids = [[1], [2]]
+        aux_hidden_states = MagicMock()
+        attn_metadata = MagicMock()
+        attn_metadata.seq_lens.shape = [2]
+        if spec_decode_metadata_is_none:
+            spec_decode_metadata = None
+        else:
+            spec_decode_metadata = MagicMock(spec=SpecDecodeMetadata)
+            spec_decode_metadata.draft_lengths_cpu = np.array([2, 3])
+        scheduler_output = MagicMock()
+        input_ids = MagicMock()
+
+        # 2. ===== Act =====
+        with patch(
+                "tpu_inference.runner.speculative_decoding_manager.device_array",
+                side_effect=lambda mesh, x: x):
+            result = self.runner.speculative_decoding_manager.propose_eagle3_draft_token_ids(
+                sampled_token_ids,
+                aux_hidden_states,
+                attn_metadata,
+                spec_decode_metadata,
+                scheduler_output,
+                input_ids,
+            )
+
+        # 3. ===== Assert =====
+        assert result == [[10, 11], [20, 21]]
+        self.runner.drafter.prepare_inputs.assert_called_once()
+        self.runner.drafter.propose.assert_called_once()
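
The parametrized expectations above follow a simple indexing rule: for each request with d draft tokens ending at cumulative scheduled position cu, the runner needs logits for the last d + 1 scheduled positions (d draft positions to verify, plus one bonus position), and the gathered index lists are zero-padded up to the next bucket size so jitted shapes stay stable. A minimal standalone sketch reproducing the test's "normal case" numbers; spec_decode_indices and pad_to_bucket are hypothetical helpers for illustration, not the package's API:

def spec_decode_indices(num_draft_tokens, cu_num_scheduled_tokens):
    logits_indices, bonus, target = [], [], []
    for d, cu in zip(num_draft_tokens, cu_num_scheduled_tokens):
        start = len(logits_indices)
        # Gather the last d + 1 scheduled positions of this request.
        logits_indices.extend(range(cu - 1 - d, cu))
        target.extend(range(start, start + d))  # draft positions to verify
        bonus.append(start + d)                 # bonus-token position
    return logits_indices, bonus, target

def pad_to_bucket(values, buckets=(16, 32, 64, 128, 256, 512, 1024)):
    # Zero-pad to the next bucket size so compiled shapes stay stable.
    size = next(b for b in buckets if b >= len(values))
    return values + [0] * (size - len(values))

# Reproduces the "normal case" expectations from the test above.
logits, bonus, target = spec_decode_indices([3, 0, 2, 0, 1],
                                            [4, 104, 107, 207, 209])
assert logits == [0, 1, 2, 3, 103, 104, 105, 106, 206, 207, 208]
assert bonus == [3, 4, 7, 8, 10]
assert target == [0, 1, 2, 5, 6, 9]
assert pad_to_bucket(logits) == logits + [0] * 5  # padded to the 16 bucket

The same rule reproduces the "high speculative tokens" case, which is why the expected index lists in both rows differ only in their draft counts and cumulative offsets.
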
tests/runner/test_structured_decoding_manager.py
@@ -0,0 +1,220 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock, patch
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
+                         SchedulerConfig, SpeculativeConfig, VllmConfig)
+from vllm.sampling_params import SamplingType
+
+from tpu_inference.runner.input_batch import CachedRequestState
+from tpu_inference.runner.tpu_runner import TPUModelRunner
+
+
+class TestStructuredDecodingManager:
+
+    def setup_method(self):
+        # Mock JAX dependencies
+        self.mock_rng_key = MagicMock()
+        self.mock_devices = [MagicMock(coords=i) for i in range(1)]
+        device_array = np.array(jax.devices()[:1]).reshape(1, 1, 1, 1)
+        self.mock_mesh = jax.make_mesh(device_array.shape,
+                                       ('data', 'attn_dp', 'expert', 'model'))
+        self.mock_rng_key = MagicMock()
+
+
+        with patch('jax.devices', return_value=self.mock_devices), \
+             patch('jax.make_mesh', return_value=self.mock_mesh), \
+             patch('jax.random.key', return_value=self.mock_rng_key), \
+             patch('tpu_inference.runner.tpu_runner.get_model', return_value=MagicMock()), \
+             patch('tpu_inference.runner.tpu_runner.make_optimized_mesh', return_value=self.mock_mesh):

+
+            model_config = ModelConfig(tokenizer_mode="auto",
+                                       trust_remote_code=False,
+                                       seed=0,
+                                       dtype='bfloat16')
+            cache_config = CacheConfig(
+                block_size=16,
+                gpu_memory_utilization=0.9,
+                swap_space=4,
+                cache_dtype="auto",
+            )
+            scheduler_config = SchedulerConfig(max_num_seqs=16,
+                                               max_model_len=1024,
+                                               is_encoder_decoder=False)
+            parallel_config = ParallelConfig(
+                pipeline_parallel_size=1,
+                tensor_parallel_size=1,
+                worker_use_ray=False,
+            )
+            speculative_config = SpeculativeConfig(
+                model='ngram',
+                num_speculative_tokens=5,
+                prompt_lookup_max=4,
+            )
+            vllm_config = VllmConfig(
+                model_config=model_config,
+                cache_config=cache_config,
+                scheduler_config=scheduler_config,
+                parallel_config=parallel_config,
+                speculative_config=speculative_config,
+                observability_config={},
+                additional_config={},
+            )
+
+            self.runner = TPUModelRunner(vllm_config,
+                                         devices=self.mock_devices)
+
+    def test_structured_decoding(self):
+        # 1. ===== Setup =====
+        # Configure runner for the test
+        self.runner.model_config.get_vocab_size = MagicMock(return_value=64)
+        self.runner._init_inputs()  # re-initialize with new vocab size
+
+        # Mock device_array to avoid JAX sharding issues with MagicMock mesh
+        def mock_device_array(mesh, *args, sharding=None, **kwargs):
+            # Simply return the arguments without any sharding (skip mesh parameter)
+            if len(args) == 1 and isinstance(args[0], tuple):
+                return args[0]  # Return tuple as is
+            elif len(args) == 1:
+                return args[0]  # Return single array as is
+            else:
+                return args  # Return all arguments as tuple
+
+        # Patch the centralized device_array function instead of runner's method
+        with patch(
+                'tpu_inference.runner.structured_decoding_manager.device_array',
+                side_effect=mock_device_array):
+
+            # Create a mock for sampling_params to avoid TypeErrors in add_request
+            mock_sampling_params = MagicMock()
+            mock_sampling_params.sampling_type = SamplingType.GREEDY
+            mock_sampling_params.temperature = 0.0
+            mock_sampling_params.top_p = 1.0
+            mock_sampling_params.top_k = -1
+            mock_sampling_params.min_tokens = 0
+            mock_sampling_params.logprobs = None
+            mock_sampling_params.logit_bias = None
+            mock_sampling_params.allowed_token_ids = set()
+            mock_sampling_params.bad_words_token_ids = None
+            mock_sampling_params.all_stop_token_ids = set()
+
+            # Add requests to the input batch
+            req1 = CachedRequestState(
+                req_id="req-1",
+                prompt_token_ids=[1],
+                output_token_ids=[],
+                sampling_params=mock_sampling_params,
+                block_ids=([1], ),
+                num_computed_tokens=1,
+                lora_request=None,
+                mm_features=[],
+                pooling_params=None,
+                generator=None,
+            )
+            req2 = CachedRequestState(
+                req_id="req-2",
+                prompt_token_ids=[2],
+                output_token_ids=[],
+                sampling_params=mock_sampling_params,
+                block_ids=([2], ),
+                num_computed_tokens=1,
+                lora_request=None,
+                mm_features=[],
+                pooling_params=None,
+                generator=None,
+            )
+            req3 = CachedRequestState(
+                req_id="req-3",
+                prompt_token_ids=[3],
+                output_token_ids=[],
+                sampling_params=mock_sampling_params,
+                block_ids=([3], ),
+                num_computed_tokens=1,
+                lora_request=None,
+                mm_features=[],
+                pooling_params=None,
+                generator=None,
+            )
+            self.runner.input_batch.add_request(req1)  # index 0
+            self.runner.input_batch.add_request(req2)  # index 1
+            self.runner.input_batch.add_request(req3)  # index 2
+            num_reqs = 3
+
+            # Mock scheduler output for structured decoding
+            # req-1 and req-3 require structured decoding
+            mock_scheduler_output = MagicMock()
+            mock_scheduler_output.structured_output_request_ids = {
+                "req-1": 0,  # maps req_id to index in grammar_bitmask
+                "req-3": 1,
+            }
+            # Bitmask: vocab_size=64, so 2 int32s per request
+            # Mask for req-1: allow tokens 0-31
+            mask1 = np.array([-1, 0], dtype=np.int32)
+            # Mask for req-3: allow tokens 32-63
+            mask2 = np.array([0, -1], dtype=np.int32)
+            mock_scheduler_output.grammar_bitmask = np.array([mask1, mask2])
+
+            # Mock logits
+            logits_shape = (num_reqs, self.runner.vocab_size)
+            mock_logits_device = jnp.ones(logits_shape, dtype=jnp.bfloat16)
+
+            # 2. ===== Test prepare_structured_decoding_input =====
+            (
+                require_struct_decoding, grammar_bitmask, arange
+            ) = self.runner.structured_decoding_manager.prepare_structured_decoding_input(
+                mock_logits_device, mock_scheduler_output)
+
+            # Assertions for prepare_structured_decoding_input
+            # require_structured_out_cpu should be [True, False, True]
+            # because req-1 is at batch index 0, req-2 at 1, req-3 at 2
+            expected_require_struct = np.array([[True], [False], [True]],
+                                               dtype=np.bool_)
+            np.testing.assert_array_equal(np.array(require_struct_decoding),
+                                          expected_require_struct)
+
+            # grammar_bitmask_cpu should have mask1 at index 0, mask2 at index 2
+            expected_grammar_bitmask = np.zeros_like(
+                self.runner.grammar_bitmask_cpu[:num_reqs])
+            expected_grammar_bitmask[0] = mask1
+            expected_grammar_bitmask[2] = mask2
+            np.testing.assert_array_equal(np.array(grammar_bitmask),
+                                          expected_grammar_bitmask)
+
+            np.testing.assert_array_equal(np.array(arange),
+                                          np.arange(0, 32, dtype=np.int32))
+
+            # 3. ===== Test structured_decode_fn =====
+            # This function is jitted, so we call it with the device arrays
+            modified_logits = self.runner.structured_decoding_manager.structured_decode_fn(
+                require_struct_decoding, grammar_bitmask, mock_logits_device,
+                arange)
+
+            modified_logits_cpu = np.array(modified_logits)
+
+            # Assertions for structured_decode_fn
+            # Logits for req-1 (index 0) should be masked for tokens 32-63
+            assert np.all(modified_logits_cpu[0, :32] == 1.0)
+            assert np.all(modified_logits_cpu[0, 32:] == -np.inf)
+
+            # Logits for req-2 (index 1) should be unchanged
+            np.testing.assert_array_equal(modified_logits_cpu[1],
+                                          np.ones(self.runner.vocab_size))
+
+            # Logits for req-3 (index 2) should be masked for tokens 0-31
+            assert np.all(modified_logits_cpu[2, :32] == -np.inf)
+            assert np.all(modified_logits_cpu[2, 32:] == 1.0)
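
The grammar-bitmask convention this test relies on packs 32 vocabulary tokens per int32 word: bit t % 32 of word t // 32 marks token t as allowed (so -1, all bits set, allows a full word of 32 tokens, and 0 blocks one), and disallowed logits are forced to -inf before sampling. A minimal numpy sketch of that convention; apply_grammar_bitmask is a hypothetical helper for illustration, not the package's API:

import numpy as np

def apply_grammar_bitmask(logits: np.ndarray, bitmask: np.ndarray) -> np.ndarray:
    # For token t, test bit (t % 32) of packed int32 word (t // 32).
    tokens = np.arange(logits.shape[-1])
    allowed = (bitmask[tokens // 32] >> (tokens % 32)) & 1
    # Keep allowed logits; force everything else to -inf.
    return np.where(allowed.astype(bool), logits, -np.inf)

logits = np.ones(64, dtype=np.float32)
mask1 = np.array([-1, 0], dtype=np.int32)  # all 32 bits set -> tokens 0-31 allowed
out = apply_grammar_bitmask(logits, mask1)
assert np.all(out[:32] == 1.0) and np.all(out[32:] == -np.inf)

This mirrors the test's expectations for req-1 (mask1 keeps tokens 0-31) and, with the words swapped as in mask2, for req-3 (tokens 32-63).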