tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (250) hide show
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +45 -15
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +41 -16
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  240. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,395 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # test_block_table.py
16
+
17
+ import jax
18
+ import jax.numpy as jnp
19
+ import numpy as np
20
+ import pytest
21
+
22
+
23
def cdiv(a: int, b: int) -> int:
    """Return the ceiling of ``a / b`` using integer arithmetic only.

    Floor-dividing by the negated divisor and negating the result rounds
    toward positive infinity for any non-zero ``b``. The more common
    ``(a + b - 1) // b`` form is only correct when ``b`` is positive.

    Args:
        a: Dividend.
        b: Divisor; must be non-zero.

    Returns:
        The smallest integer greater than or equal to ``a / b``.

    Raises:
        ZeroDivisionError: If ``b`` is zero.
    """
    return -(a // -b)
26
+
27
+
28
class BlockTable:
    """Block table with a NumPy staging copy and a JAX device copy.

    Each row belongs to one request slot and stores that request's block
    ids from column 0 onwards; unused columns hold 0. All edits are applied
    to the NumPy copy and only reach the JAX array when ``commit`` is called.
    """

    def __init__(
        self,
        max_num_reqs: int,
        max_num_blocks_per_req: int,
        max_num_batched_tokens: int,
        pin_memory: bool,  # Note: pin_memory is not used in JAX
    ):
        """Allocate zero-filled tables of shape (max_num_reqs, max_num_blocks_per_req).

        ``max_num_batched_tokens`` and ``pin_memory`` are accepted but not
        read by this class.
        """
        self.max_num_reqs = max_num_reqs
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.block_table = jnp.zeros((max_num_reqs, max_num_blocks_per_req),
                                     dtype=jnp.int32)
        self.block_table_cpu = np.zeros((max_num_reqs, max_num_blocks_per_req),
                                        dtype=np.int32)
        # Number of valid block ids currently stored in each row.
        self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)

    def append_row(self, block_ids: list[int], row_idx: int) -> None:
        """Append ``block_ids`` after the blocks already stored in ``row_idx``.

        An empty ``block_ids`` is a no-op.

        Raises:
            ValueError: If the row would hold more than
                ``max_num_blocks_per_req`` blocks. Nothing is modified in
                that case.
        """
        if not block_ids:
            return
        start = int(self.num_blocks_per_row[row_idx])
        end = start + len(block_ids)
        # Validate before touching any state. Without this check NumPy either
        # raises after the counter was already advanced, or silently drops a
        # single id by broadcasting it into an empty slice.
        if end > self.max_num_blocks_per_req:
            raise ValueError(
                f"Cannot append {len(block_ids)} block(s) to row {row_idx}: "
                f"{start} already stored, capacity is "
                f"{self.max_num_blocks_per_req}.")
        self.block_table_cpu[row_idx, start:end] = block_ids
        # Only advance the counter once the write has succeeded.
        self.num_blocks_per_row[row_idx] = end

    def add_row(self, block_ids: list[int], row_idx: int) -> None:
        """Replace the contents of ``row_idx`` with ``block_ids``."""
        self.num_blocks_per_row[row_idx] = 0
        # Clear the row for a clean overwrite
        self.block_table_cpu[row_idx].fill(0)
        self.append_row(block_ids, row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        """Copy row ``src`` onto row ``tgt``; the source row is left as is."""
        num_blocks = self.num_blocks_per_row[src]
        self.block_table_cpu[tgt, :num_blocks] = self.block_table_cpu[
            src, :num_blocks]
        # Clear the rest of the target row to avoid stale data
        self.block_table_cpu[tgt, num_blocks:].fill(0)
        self.num_blocks_per_row[tgt] = num_blocks

    def swap_row(self, src: int, tgt: int) -> None:
        """Exchange both the block counts and the contents of two rows."""
        # Fancy indexing on the right-hand side yields a copy, so the two
        # rows can be exchanged in a single assignment.
        self.num_blocks_per_row[[src,
                                 tgt]] = self.num_blocks_per_row[[tgt, src]]
        self.block_table_cpu[[src, tgt]] = self.block_table_cpu[[tgt, src]]

    def commit(self, num_reqs: int) -> None:
        """Copy the first ``num_reqs`` CPU rows into the device table.

        JAX arrays are immutable, so ``.at[...].set`` returns a new array
        which is rebound to ``self.block_table``.
        """
        self.block_table = self.block_table.at[:num_reqs].set(
            self.block_table_cpu[:num_reqs])

    def clear(self) -> None:
        """Reset the device table, the CPU table and the row counters to 0."""
        self.block_table = jnp.zeros_like(self.block_table)
        self.block_table_cpu.fill(0)
        self.num_blocks_per_row.fill(0)

    def get_device_tensor(self) -> jax.Array:
        """Return the JAX array as of the last ``commit`` or ``clear``."""
        return self.block_table

    def get_cpu_tensor(self) -> np.ndarray:
        """Return the NumPy staging table (the live array, not a copy)."""
        return self.block_table_cpu
89
+
90
+
91
class MultiGroupBlockTable:
    """Holds one BlockTable per KV cache group and fans every call out to all."""

    def __init__(
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_batched_tokens: int,
        pin_memory: bool,
        block_sizes: list[int],
    ) -> None:
        # One table per block size; each row needs enough blocks to cover
        # max_model_len tokens at that size.
        tables = []
        for size in block_sizes:
            blocks_per_req = cdiv(max_model_len, size)
            tables.append(
                BlockTable(max_num_reqs, blocks_per_req,
                           max_num_batched_tokens, pin_memory))
        self.block_tables = tables

    def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:
        """Append ``block_ids[g]`` to row ``row_idx`` of group ``g``'s table."""
        for group, table in enumerate(self.block_tables):
            group_ids = block_ids[group]
            table.append_row(group_ids, row_idx)

    def add_row(self, block_ids: list[list[int]], row_idx: int) -> None:
        """Overwrite row ``row_idx`` of group ``g``'s table with ``block_ids[g]``."""
        for group, table in enumerate(self.block_tables):
            group_ids = block_ids[group]
            table.add_row(group_ids, row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        """Apply ``move_row(src, tgt)`` to every group's table."""
        for table in self.block_tables:
            table.move_row(src, tgt)

    def swap_row(self, src: int, tgt: int) -> None:
        """Apply ``swap_row(src, tgt)`` to every group's table."""
        for table in self.block_tables:
            table.swap_row(src, tgt)

    def commit(self, num_reqs: int) -> None:
        """Commit the first ``num_reqs`` rows of every group's table."""
        for table in self.block_tables:
            table.commit(num_reqs)

    def clear(self) -> None:
        """Clear every group's table."""
        for table in self.block_tables:
            table.clear()

    def __getitem__(self, idx: int) -> "BlockTable":
        """Return the BlockTable of group ``idx``."""
        return self.block_tables[idx]
137
+
138
+
139
+ # --- Pytest Fixtures ---
140
+
141
+
142
@pytest.fixture
def block_table_params():
    """Keyword arguments shared by tests that construct a BlockTable."""
    num_reqs = 8
    blocks_per_req = 16
    return {
        "max_num_reqs": num_reqs,
        "max_num_blocks_per_req": blocks_per_req,
        "max_num_batched_tokens": num_reqs * blocks_per_req,
        "pin_memory": False,
    }
151
+
152
+
153
@pytest.fixture
def block_table(block_table_params):
    """Build a new BlockTable from ``block_table_params`` for the requesting test."""
    table = BlockTable(**block_table_params)
    return table
157
+
158
+
159
+ # --- Test Cases ---
160
+
161
+ ##
162
+ ## BlockTable Tests
163
+ ##
164
+
165
+
166
class TestBlockTable:
    """Tests for the single BlockTable class."""

    def test_init(self, block_table, block_table_params):
        """Test constructor and initial state."""
        bt = block_table
        params = block_table_params
        expected_shape = (
            params["max_num_reqs"],
            params["max_num_blocks_per_req"],
        )

        assert bt.max_num_reqs == params["max_num_reqs"]
        assert bt.max_num_blocks_per_req == params["max_num_blocks_per_req"]

        # Check CPU table
        assert bt.block_table_cpu.shape == expected_shape
        assert bt.block_table_cpu.dtype == np.int32
        np.testing.assert_array_equal(bt.block_table_cpu, 0)

        # Check device table
        assert bt.block_table.shape == expected_shape
        assert bt.block_table.dtype == jnp.int32
        np.testing.assert_array_equal(np.array(bt.block_table), 0)

        # Check block counter per row
        assert bt.num_blocks_per_row.shape == (params["max_num_reqs"], )
        np.testing.assert_array_equal(bt.num_blocks_per_row, 0)

    def test_add_and_append_row(self, block_table):
        """Test adding and appending blocks to a row."""
        # Append to row 0
        block_table.append_row([1, 2, 3], row_idx=0)
        assert block_table.num_blocks_per_row[0] == 3
        np.testing.assert_array_equal(block_table.block_table_cpu[0, :3],
                                      [1, 2, 3])

        # Append more to row 0
        block_table.append_row([4, 5], row_idx=0)
        assert block_table.num_blocks_per_row[0] == 5
        np.testing.assert_array_equal(block_table.block_table_cpu[0, :5],
                                      [1, 2, 3, 4, 5])

        # Add (overwrite) row 1
        block_table.add_row([10, 11], row_idx=1)
        assert block_table.num_blocks_per_row[1] == 2
        np.testing.assert_array_equal(block_table.block_table_cpu[1, :2],
                                      [10, 11])

        # Add (overwrite) row 0
        block_table.add_row([6, 7, 8, 9], row_idx=0)
        assert block_table.num_blocks_per_row[0] == 4
        np.testing.assert_array_equal(block_table.block_table_cpu[0, :4],
                                      [6, 7, 8, 9])
        # Row 0 previously held 5 blocks; everything past the new 4 must be 0.
        np.testing.assert_array_equal(block_table.block_table_cpu[0, 4:], 0)

    def test_move_row(self, block_table):
        """Test moving a row's content."""
        block_table.add_row([10, 20, 30], row_idx=2)
        # The target starts out LONGER than the source, so stale ids would
        # survive at columns 3 and 4 if move_row did not clear the tail.
        block_table.add_row([99, 98, 97, 96, 95], row_idx=5)

        block_table.move_row(src=2, tgt=5)

        # Check target row
        assert block_table.num_blocks_per_row[5] == 3
        np.testing.assert_array_equal(block_table.get_cpu_tensor()[5, :3],
                                      [10, 20, 30])
        # Check old data is cleared
        np.testing.assert_array_equal(block_table.get_cpu_tensor()[5, 3:], 0)

        # Check source row (should be unchanged)
        assert block_table.num_blocks_per_row[2] == 3
        np.testing.assert_array_equal(block_table.get_cpu_tensor()[2, :3],
                                      [10, 20, 30])

    def test_swap_row(self, block_table):
        """Test swapping two rows."""
        row_2_data = [10, 20, 30]
        row_5_data = [99, 88]
        block_table.add_row(row_2_data, row_idx=2)
        block_table.add_row(row_5_data, row_idx=5)

        block_table.swap_row(src=2, tgt=5)

        # Check that data and counts are swapped
        assert block_table.num_blocks_per_row[2] == 2
        assert block_table.num_blocks_per_row[5] == 3
        np.testing.assert_array_equal(block_table.block_table_cpu[2, :2],
                                      row_5_data)
        np.testing.assert_array_equal(block_table.block_table_cpu[5, :3],
                                      row_2_data)
        # Whole rows are exchanged, so nothing may remain past the new counts.
        np.testing.assert_array_equal(block_table.block_table_cpu[2, 2:], 0)
        np.testing.assert_array_equal(block_table.block_table_cpu[5, 3:], 0)

    def test_commit(self, block_table):
        """Test committing the CPU table to the JAX device table."""
        block_table.add_row([1, 2, 3], row_idx=0)
        block_table.add_row([4, 5], row_idx=1)
        num_reqs_to_commit = 2

        # Before commit, device tensor is all zeros
        np.testing.assert_array_equal(
            np.array(block_table.get_device_tensor()), 0)

        block_table.commit(num_reqs_to_commit)
        device_table = np.array(block_table.get_device_tensor())

        # After commit, device tensor should match committed part of CPU tensor
        np.testing.assert_array_equal(
            device_table[:num_reqs_to_commit],
            block_table.get_cpu_tensor()[:num_reqs_to_commit],
        )
        # The rest of the device tensor should still be zero
        np.testing.assert_array_equal(device_table[num_reqs_to_commit:], 0)

    def test_clear(self, block_table):
        """Test clearing all table data."""
        block_table.add_row([1, 2, 3], row_idx=0)
        block_table.commit(num_reqs=1)

        # Pre-clear check
        assert np.any(block_table.get_cpu_tensor())
        assert jnp.any(block_table.get_device_tensor())
        assert np.any(block_table.num_blocks_per_row)

        block_table.clear()

        # Post-clear check
        np.testing.assert_array_equal(block_table.get_cpu_tensor(), 0)
        np.testing.assert_array_equal(
            np.array(block_table.get_device_tensor()), 0)
        np.testing.assert_array_equal(block_table.num_blocks_per_row, 0)
299
+
300
+
301
+ # ------------------------------------
302
+ # MultiGroupBlockTable Tests
303
+ # ------------------------------------
304
+
305
+
306
class TestMultiGroupBlockTable:
    """Exercises MultiGroupBlockTable with two KV cache groups."""

    @pytest.fixture
    def multi_table_params(self):
        """Constructor arguments: 4 request slots, block sizes 16 and 8."""
        model_len = 32
        num_reqs = 4
        return {
            "max_num_reqs": num_reqs,
            "max_model_len": model_len,
            "max_num_batched_tokens": num_reqs * model_len,
            "pin_memory": False,
            "block_sizes": [16, 8],  # Two groups
        }

    @pytest.fixture
    def multi_table(self, multi_table_params):
        """A new MultiGroupBlockTable built from ``multi_table_params``."""
        return MultiGroupBlockTable(**multi_table_params)

    def test_init(self, multi_table, multi_table_params):
        """One BlockTable per block size, each sized from max_model_len."""
        sizes = multi_table_params["block_sizes"]
        model_len = multi_table_params["max_model_len"]

        assert len(multi_table.block_tables) == len(sizes)
        for group in (0, 1):
            assert isinstance(multi_table[group], BlockTable)

        # Blocks per request: 32 / 16 = 2 for group 0, 32 / 8 = 4 for group 1.
        for group in (0, 1):
            expected_blocks = cdiv(model_len, sizes[group])
            assert multi_table[group].max_num_blocks_per_req == expected_blocks

    def test_add_row(self, multi_table):
        """add_row writes each group's ids into that group's table."""
        block_ids = [[101, 102], [201, 202, 203]]
        multi_table.add_row(block_ids, row_idx=0)

        # Group 0 receives 2 ids, group 1 receives 3.
        for group, expected_count in enumerate((2, 3)):
            table = multi_table[group]
            assert table.num_blocks_per_row[0] == expected_count
            np.testing.assert_array_equal(
                table.get_cpu_tensor()[0, :expected_count], block_ids[group])

    def test_swap_row(self, multi_table):
        """swap_row exchanges the two rows in every group's table."""
        row1_data = [[11], [11, 22]]
        row3_data = [[33], [33, 44, 55]]
        multi_table.add_row(row1_data, row_idx=1)
        multi_table.add_row(row3_data, row_idx=3)

        multi_table.swap_row(src=1, tgt=3)

        # After the swap row 1 carries row 3's data and vice versa.
        for row, swapped_in in ((1, row3_data), (3, row1_data)):
            for group, expected in enumerate(swapped_in):
                table = multi_table[group]
                count = len(expected)
                assert table.num_blocks_per_row[row] == count
                np.testing.assert_array_equal(
                    table.get_cpu_tensor()[row, :count], expected)

    def test_commit_and_clear(self, multi_table):
        """commit and clear reach every group's table."""
        multi_table.add_row([[1], [1, 2]], row_idx=0)
        multi_table.commit(num_reqs=1)

        # Every device table must now be non-zero and equal to its CPU copy.
        for group_table in multi_table.block_tables:
            on_device = group_table.get_device_tensor()
            assert jnp.any(on_device)
            np.testing.assert_array_equal(np.array(on_device),
                                          group_table.get_cpu_tensor())

        multi_table.clear()

        # After clear, both copies and the counters are all zero.
        for group_table in multi_table.block_tables:
            np.testing.assert_array_equal(group_table.get_cpu_tensor(), 0)
            np.testing.assert_array_equal(
                np.array(group_table.get_device_tensor()), 0)
            np.testing.assert_array_equal(group_table.num_blocks_per_row, 0)
@@ -0,0 +1,226 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import numpy as np
16
+ import pytest
17
+ from vllm.sampling_params import SamplingParams
18
+
19
+ from tpu_inference.runner.input_batch import CachedRequestState, InputBatch
20
+
21
# Sizing defaults shared by every InputBatch built in this test module.
MAX_NUM_REQS = 8  # passed as max_num_reqs
MAX_MODEL_LEN = 1_024  # passed as max_model_len
MAX_NUM_BATCHED_TOKENS = 2_048  # passed as max_num_batched_tokens
VOCAB_SIZE = 32_000  # passed as vocab_size
# Passed as block_sizes; the first entry also sizes the dummy block ids.
BLOCK_SIZES = [16]
27
+
28
+
29
@pytest.fixture
def input_batch():
    """Return a newly constructed InputBatch built from the module defaults.

    The batch is created with speculative decoding enabled and without
    pinned memory.
    """
    batch_kwargs = dict(
        max_num_reqs=MAX_NUM_REQS,
        max_model_len=MAX_MODEL_LEN,
        max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS,
        pin_memory=False,
        vocab_size=VOCAB_SIZE,
        block_sizes=BLOCK_SIZES,
        is_spec_decode=True,
    )
    return InputBatch(**batch_kwargs)
41
+
42
+
43
def create_dummy_request(req_id: str,
                         prompt_len: int = 10,
                         output_len: int = 5,
                         sampling_params: SamplingParams = None,
                         block_ids=None) -> CachedRequestState:
    """Build a CachedRequestState filled with synthetic token and block ids.

    Args:
        req_id: Identifier stored on the request.
        prompt_len: Number of prompt tokens; their ids are
            ``0 .. prompt_len - 1``.
        output_len: Number of already-generated tokens; their ids continue
            from ``prompt_len``.
        sampling_params: Sampling configuration. When ``None``, a
            non-greedy default (temperature=0.8, top_p=0.9, top_k=50) is
            used.
        block_ids: Block ids for the request. When ``None``, a single
            group holding enough consecutive ids (starting at 1) to cover
            all tokens is generated.

    Returns:
        A CachedRequestState with no multimodal features, pooling params or
        LoRA request, and ``num_computed_tokens`` set to 0.
    """
    if sampling_params is None:
        sampling_params = SamplingParams(temperature=0.8, top_p=0.9, top_k=50)

    prompt_token_ids = list(range(prompt_len))
    output_token_ids = list(range(prompt_len, prompt_len + output_len))

    if block_ids is None:
        # Ceiling division: number of blocks needed to hold every token.
        block_size = BLOCK_SIZES[0]
        total_tokens = prompt_len + output_len
        num_blocks = (total_tokens + block_size - 1) // block_size
        # The outer list is indexed per block table (the multi-table tests
        # in this suite pass e.g. [[11], [11, 22]] for two tables), and
        # BLOCK_SIZES configures exactly one table. Emit a single inner
        # list holding all block ids rather than one single-id list per
        # block, which would describe `num_blocks` separate tables.
        # NOTE(review): assumes CachedRequestState.block_ids follows the
        # same per-table layout - confirm against InputBatch.add_request.
        block_ids = [list(range(1, num_blocks + 1))]

    return CachedRequestState(
        req_id=req_id,
        prompt_token_ids=prompt_token_ids,
        mm_features=[],
        sampling_params=sampling_params,
        pooling_params=None,
        block_ids=block_ids,
        num_computed_tokens=0,
        lora_request=None,
        output_token_ids=output_token_ids,
    )
72
+
73
+
74
def test_initialization(input_batch: InputBatch):
    """A freshly built InputBatch is empty, greedy and in spec-decode mode."""
    # Capacity comes straight from the constructor argument.
    assert input_batch.max_num_reqs == MAX_NUM_REQS

    # No request has been registered yet.
    assert input_batch.num_reqs == 0
    assert input_batch.req_ids == []
    assert len(input_batch.req_id_to_index) == 0

    # With nothing in the batch it reports all-greedy; spec decode was
    # switched on by the fixture.
    assert input_batch.all_greedy
    assert input_batch.is_spec_decode
82
+
83
+
84
def test_add_request(input_batch: InputBatch):
    """Adding one request fills slot 0 with its ids, tokens and sampling data.

    Float-valued sampling parameters are compared with ``pytest.approx``
    so the test does not depend on the precision of the CPU buffers.
    """
    prompt_len, output_len = 20, 4
    total_len = prompt_len + output_len
    req = create_dummy_request("req-1",
                               prompt_len=prompt_len,
                               output_len=output_len)
    input_batch.add_request(req)

    assert input_batch.num_reqs == 1
    assert "req-1" in input_batch.req_id_to_index
    assert input_batch.req_id_to_index["req-1"] == 0
    assert input_batch.req_ids == ["req-1"]
    assert len(input_batch.spec_decode_unsupported_reqs) == 0

    # Verify token data
    assert input_batch.num_prompt_tokens[0] == prompt_len
    assert input_batch.num_tokens[0] == total_len
    assert input_batch.num_tokens_no_spec[0] == total_len
    expected_tokens = np.array(req.prompt_token_ids + req.output_token_ids)
    np.testing.assert_array_equal(input_batch.token_ids_cpu[0, :total_len],
                                  expected_tokens)

    # Verify sampling params. 0.8 and 0.9 are not exactly representable in
    # binary floating point, so an exact == against a Python float fails
    # whenever the buffer stores a narrower float type.
    # NOTE(review): buffer dtype is not visible here - confirm.
    assert input_batch.temperature_cpu[0] == pytest.approx(0.8)
    assert input_batch.top_p_cpu[0] == pytest.approx(0.9)
    assert input_batch.top_k_cpu[0] == 50
107
+
108
+
109
def test_add_multiple_requests(input_batch: InputBatch):
    """Requests added back to back receive consecutive slot indices."""
    first = create_dummy_request("req-1")
    second = create_dummy_request("req-2")
    for request in (first, second):
        input_batch.add_request(request)

    assert input_batch.num_reqs == 2
    assert input_batch.req_ids == ["req-1", "req-2"]
    assert input_batch.req_id_to_index["req-1"] == 0
    assert input_batch.req_id_to_index["req-2"] == 1

    # Slot 1 must report the second request's prompt plus output length,
    # both with and without speculative tokens.
    second_total = len(second.prompt_token_ids) + len(
        second.output_token_ids)
    assert input_batch.num_tokens[1] == second_total
    assert input_batch.num_tokens_no_spec[1] == second_total
125
+
126
+
127
def test_remove_request(input_batch: InputBatch):
    """Removing a request frees its slot and drops it from the bookkeeping.

    The removed slot is left empty (``None``); the remaining request keeps
    its original index until the batch is condensed.
    """
    req1 = create_dummy_request("req-1")
    req2 = create_dummy_request("req-2")
    input_batch.add_request(req1)
    input_batch.add_request(req2)

    removed_index = input_batch.remove_request("req-1")

    assert removed_index == 0
    assert input_batch.num_reqs == 1
    assert "req-1" not in input_batch.req_id_to_index
    assert input_batch._req_ids[0] is None  # Slot is now empty
    assert input_batch._req_ids[1] == "req-2"
    assert "req-1" not in input_batch.greedy_reqs
    # The dummy request samples with temperature 0.8, so the greedy check
    # above cannot fail on its own; also check the random-sampling set.
    # NOTE(review): assumes non-greedy requests are tracked in
    # `random_reqs` and dropped by remove_request - confirm.
    assert "req-1" not in input_batch.random_reqs
142
+
143
+
144
def test_condense(input_batch: InputBatch):
    """Condensing after removals packs the survivors into the lowest slots.

    Each request gets a different prompt length so the token-count
    assertions can tell which request ended up in which slot.
    """
    reqs = [
        create_dummy_request(f"req-{i}", prompt_len=10 + i) for i in range(4)
    ]
    for req in reqs:
        input_batch.add_request(req)

    # Remove requests from the middle and start
    input_batch.remove_request("req-1")
    input_batch.remove_request("req-0")

    # Before condense: [None, None, "req-2", "req-3"]
    assert input_batch._req_ids[0] is None
    assert input_batch._req_ids[1] is None
    assert input_batch.num_reqs == 2

    # Condense should move req-2 and req-3 to the front; the empty slots
    # are passed in descending order.
    empty_indices = sorted([0, 1], reverse=True)
    input_batch.condense(empty_indices)

    assert input_batch.num_reqs == 2
    assert len(input_batch.req_ids) == 2
    assert input_batch.req_ids == ["req-3", "req-2"]
    assert input_batch.req_id_to_index["req-2"] == 1
    assert input_batch.req_id_to_index["req-3"] == 0

    # Check that per-request properties moved with their request: slot 0
    # now holds req-3 and slot 1 holds req-2.
    for slot, moved in ((0, reqs[3]), (1, reqs[2])):
        expected_len = len(moved.prompt_token_ids) + len(
            moved.output_token_ids)
        assert input_batch.num_tokens[slot] == expected_len
        assert input_batch.num_tokens_no_spec[slot] == expected_len
174
+
175
+
176
def test_swap_states(input_batch: InputBatch):
    """swap_states exchanges ids, index mapping and per-slot data."""
    first = create_dummy_request("req-1", prompt_len=10, output_len=1)
    second = create_dummy_request(
        "req-2",
        prompt_len=20,
        output_len=2,
        sampling_params=SamplingParams(top_p=0.5),
    )
    for request in (first, second):
        input_batch.add_request(request)

    # Snapshot both slots; the token rows are copied so the swap cannot
    # alter the saved values.
    tokens_before = [input_batch.token_ids_cpu[slot].copy() for slot in (0, 1)]
    top_p_before = [input_batch.top_p_cpu[slot] for slot in (0, 1)]

    input_batch.swap_states(0, 1)

    # Ids and the id -> index mapping are exchanged.
    assert input_batch.req_ids == ["req-2", "req-1"]
    assert input_batch.req_id_to_index["req-1"] == 1
    assert input_batch.req_id_to_index["req-2"] == 0

    # Each slot now carries what the other slot held before the swap.
    assert input_batch.top_p_cpu[0] == top_p_before[1]
    assert input_batch.top_p_cpu[1] == top_p_before[0]
    np.testing.assert_array_equal(input_batch.token_ids_cpu[0],
                                  tokens_before[1])
    np.testing.assert_array_equal(input_batch.token_ids_cpu[1],
                                  tokens_before[0])
207
+
208
+
209
def test_all_greedy_property(input_batch: InputBatch):
    """`all_greedy` holds until an id is present in `random_reqs`."""
    # An empty batch reports all-greedy.
    assert input_batch.all_greedy

    # A temperature-0 request keeps the property true.
    greedy_params = SamplingParams(temperature=0.0)
    greedy_request = create_dummy_request("req-g",
                                          sampling_params=greedy_params)
    input_batch.add_request(greedy_request)
    assert input_batch.all_greedy

    # Insert an id directly into the random set (no real request is
    # added) to flip the property, then take it out to restore it.
    fake_random_id = "req-r"
    input_batch.random_reqs.add(fake_random_id)
    assert not input_batch.all_greedy

    input_batch.random_reqs.remove(fake_random_id)
    assert input_batch.all_greedy