tpu-inference 0.11.1.dev202511220812__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (257)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +317 -34
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +406 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +320 -0
  64. tests/layers/vllm/test_unquantized.py +662 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +26 -6
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +110 -12
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +2 -45
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +15 -10
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +92 -8
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +25 -4
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +807 -230
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +218 -137
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +25 -12
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  154. tpu_inference/layers/common/quant_methods.py +15 -0
  155. tpu_inference/layers/common/quantization.py +282 -0
  156. tpu_inference/layers/common/sharding.py +32 -9
  157. tpu_inference/layers/common/utils.py +94 -0
  158. tpu_inference/layers/jax/__init__.py +13 -0
  159. tpu_inference/layers/jax/attention/__init__.py +13 -0
  160. tpu_inference/layers/jax/attention/attention.py +19 -6
  161. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  162. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  163. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  164. tpu_inference/layers/jax/base.py +14 -0
  165. tpu_inference/layers/jax/constants.py +13 -0
  166. tpu_inference/layers/jax/layers.py +14 -0
  167. tpu_inference/layers/jax/misc.py +14 -0
  168. tpu_inference/layers/jax/moe/__init__.py +13 -0
  169. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  170. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  171. tpu_inference/layers/jax/moe/moe.py +43 -3
  172. tpu_inference/layers/jax/pp_utils.py +53 -0
  173. tpu_inference/layers/jax/rope.py +14 -0
  174. tpu_inference/layers/jax/rope_interface.py +14 -0
  175. tpu_inference/layers/jax/sample/__init__.py +13 -0
  176. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  177. tpu_inference/layers/jax/sample/sampling.py +15 -1
  178. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  179. tpu_inference/layers/jax/transformer_block.py +14 -0
  180. tpu_inference/layers/vllm/__init__.py +13 -0
  181. tpu_inference/layers/vllm/attention.py +4 -4
  182. tpu_inference/layers/vllm/fused_moe.py +101 -494
  183. tpu_inference/layers/vllm/linear.py +64 -0
  184. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  185. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  186. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  187. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  188. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  189. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  191. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +23 -8
  192. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +172 -176
  193. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  194. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  195. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  196. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +42 -25
  197. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  198. tpu_inference/layers/vllm/quantization/mxfp4.py +137 -178
  199. tpu_inference/layers/vllm/quantization/unquantized.py +157 -233
  200. tpu_inference/lora/__init__.py +13 -0
  201. tpu_inference/lora/torch_lora_ops.py +8 -13
  202. tpu_inference/models/__init__.py +13 -0
  203. tpu_inference/models/common/__init__.py +13 -0
  204. tpu_inference/models/common/model_loader.py +112 -35
  205. tpu_inference/models/jax/__init__.py +13 -0
  206. tpu_inference/models/jax/deepseek_v3.py +267 -157
  207. tpu_inference/models/jax/gpt_oss.py +26 -10
  208. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  209. tpu_inference/models/jax/llama3.py +99 -36
  210. tpu_inference/models/jax/llama4.py +14 -0
  211. tpu_inference/models/jax/llama_eagle3.py +18 -5
  212. tpu_inference/models/jax/llama_guard_4.py +15 -1
  213. tpu_inference/models/jax/qwen2.py +17 -2
  214. tpu_inference/models/jax/qwen2_5_vl.py +179 -51
  215. tpu_inference/models/jax/qwen3.py +17 -2
  216. tpu_inference/models/jax/utils/__init__.py +13 -0
  217. tpu_inference/models/jax/utils/file_utils.py +14 -0
  218. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  219. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  220. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +92 -32
  221. tpu_inference/models/jax/utils/weight_utils.py +234 -155
  222. tpu_inference/models/vllm/__init__.py +13 -0
  223. tpu_inference/models/vllm/vllm_model_wrapper.py +32 -8
  224. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  225. tpu_inference/platforms/__init__.py +14 -0
  226. tpu_inference/platforms/tpu_platform.py +51 -72
  227. tpu_inference/runner/__init__.py +13 -0
  228. tpu_inference/runner/compilation_manager.py +180 -80
  229. tpu_inference/runner/kv_cache.py +54 -20
  230. tpu_inference/runner/kv_cache_manager.py +55 -33
  231. tpu_inference/runner/lora_utils.py +16 -1
  232. tpu_inference/runner/multimodal_manager.py +16 -2
  233. tpu_inference/runner/persistent_batch_manager.py +54 -2
  234. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  235. tpu_inference/runner/structured_decoding_manager.py +16 -3
  236. tpu_inference/runner/tpu_runner.py +124 -61
  237. tpu_inference/runner/utils.py +2 -2
  238. tpu_inference/spec_decode/__init__.py +13 -0
  239. tpu_inference/spec_decode/jax/__init__.py +13 -0
  240. tpu_inference/spec_decode/jax/eagle3.py +84 -22
  241. tpu_inference/tpu_info.py +14 -0
  242. tpu_inference/utils.py +72 -44
  243. tpu_inference/worker/__init__.py +13 -0
  244. tpu_inference/worker/tpu_worker.py +66 -52
  245. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +8 -9
  246. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  247. tpu_inference/layers/vllm/linear_common.py +0 -186
  248. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  249. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  250. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  251. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  252. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  253. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  254. tpu_inference-0.11.1.dev202511220812.dist-info/RECORD +0 -174
  255. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  256. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  257. {tpu_inference-0.11.1.dev202511220812.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
tests/core/test_init.py CHANGED
@@ -1,3 +1,17 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
1
15
  import importlib
2
16
  import unittest
3
17
  from unittest.mock import patch
tests/distributed/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
tests/distributed/test_distributed_utils.py ADDED
@@ -0,0 +1,120 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from collections import namedtuple
16
+
17
+ import pytest
18
+
19
+ from tpu_inference.distributed.utils import get_device_topology_order_id
20
+
21
+ # Mock TpuDevice object to simulate the real one.
22
+ TpuDevice = namedtuple('TpuDevice',
23
+ ['id', 'process_index', 'coords', 'core_on_chip'])
24
+
25
+
26
+ def test_get_device_topology_order_id():
27
+ """
28
+ Tests the get_device_topology_order_id function with a mock topology.
29
+ """
30
+ # V7x
31
+ global_devices = [
32
+ TpuDevice(id=0, process_index=0, coords=(0, 0, 0), core_on_chip=0),
33
+ TpuDevice(id=1, process_index=0, coords=(0, 0, 0), core_on_chip=1),
34
+ TpuDevice(id=2, process_index=0, coords=(1, 0, 0), core_on_chip=0),
35
+ TpuDevice(id=3, process_index=0, coords=(1, 0, 0), core_on_chip=1),
36
+ TpuDevice(id=4, process_index=0, coords=(0, 1, 0), core_on_chip=0),
37
+ TpuDevice(id=5, process_index=0, coords=(0, 1, 0), core_on_chip=1),
38
+ TpuDevice(id=6, process_index=0, coords=(1, 1, 0), core_on_chip=0),
39
+ TpuDevice(id=7, process_index=0, coords=(1, 1, 0), core_on_chip=1),
40
+ TpuDevice(id=8, process_index=1, coords=(0, 0, 1), core_on_chip=0),
41
+ TpuDevice(id=9, process_index=1, coords=(0, 0, 1), core_on_chip=1),
42
+ TpuDevice(id=10, process_index=1, coords=(1, 0, 1), core_on_chip=0),
43
+ TpuDevice(id=11, process_index=1, coords=(1, 0, 1), core_on_chip=1),
44
+ TpuDevice(id=12, process_index=1, coords=(0, 1, 1), core_on_chip=0),
45
+ TpuDevice(id=13, process_index=1, coords=(0, 1, 1), core_on_chip=1),
46
+ TpuDevice(id=14, process_index=1, coords=(1, 1, 1), core_on_chip=0),
47
+ TpuDevice(id=15, process_index=1, coords=(1, 1, 1), core_on_chip=1),
48
+ ]
49
+
50
+ local_devices_1 = global_devices[:8]
51
+ local_devices_2 = global_devices[8:]
52
+
53
+ assert get_device_topology_order_id(local_devices_1, global_devices) == 0
54
+ assert get_device_topology_order_id(local_devices_2, global_devices) == 1
55
+
56
+ # Test with unsorted in global_devices
57
+ shuffled_z_global_devices = [
58
+ TpuDevice(id=8, process_index=1, coords=(0, 0, 1), core_on_chip=0),
59
+ TpuDevice(id=0, process_index=0, coords=(0, 0, 0), core_on_chip=0),
60
+ ]
61
+ local_devices_z1 = [
62
+ TpuDevice(id=8, process_index=1, coords=(0, 0, 1), core_on_chip=0)
63
+ ]
64
+ local_devices_z0 = [
65
+ TpuDevice(id=0, process_index=0, coords=(0, 0, 0), core_on_chip=0)
66
+ ]
67
+
68
+ assert get_device_topology_order_id(local_devices_z0,
69
+ shuffled_z_global_devices) == 0
70
+ assert get_device_topology_order_id(local_devices_z1,
71
+ shuffled_z_global_devices) == 1
72
+
73
+ #v6e
74
+ global_devices = [
75
+ TpuDevice(id=0, process_index=0, coords=(0, 0, 0), core_on_chip=0),
76
+ TpuDevice(id=1, process_index=1, coords=(1, 0, 0), core_on_chip=0),
77
+ TpuDevice(id=2, process_index=2, coords=(0, 1, 0), core_on_chip=0),
78
+ TpuDevice(id=3, process_index=3, coords=(1, 1, 0), core_on_chip=0)
79
+ ]
80
+ local_devices = [
81
+ TpuDevice(id=0, process_index=0, coords=(0, 0, 0), core_on_chip=0)
82
+ ]
83
+ assert get_device_topology_order_id(local_devices, global_devices) == 0
84
+
85
+ local_devices = [
86
+ TpuDevice(id=1, process_index=1, coords=(1, 0, 0), core_on_chip=0)
87
+ ]
88
+ assert get_device_topology_order_id(local_devices, global_devices) == 2
89
+
90
+ local_devices = [
91
+ TpuDevice(id=2, process_index=2, coords=(0, 1, 0), core_on_chip=0)
92
+ ]
93
+ assert get_device_topology_order_id(local_devices, global_devices) == 1
94
+
95
+ local_devices = [
96
+ TpuDevice(id=3, process_index=3, coords=(1, 1, 0), core_on_chip=0)
97
+ ]
98
+ assert get_device_topology_order_id(local_devices, global_devices) == 3
99
+
100
+
101
+ def test_get_device_topology_order_id_empty_local():
102
+ """
103
+ Tests that a ValueError is raised for empty local_devices.
104
+ """
105
+ with pytest.raises(ValueError, match="local_devices cannot be empty"):
106
+ get_device_topology_order_id([], [])
107
+
108
+
109
+ def test_get_device_topology_order_id_not_in_global():
110
+ """
111
+ Tests that a ValueError is raised if local z-coordinate is not in global list.
112
+ """
113
+ global_devices = [
114
+ TpuDevice(id=0, process_index=0, coords=(0, 0, 0), core_on_chip=0),
115
+ ]
116
+ local_devices = [
117
+ TpuDevice(id=1, process_index=1, coords=(0, 0, 1), core_on_chip=0),
118
+ ]
119
+ with pytest.raises(ValueError, match="do not exist in the global device:"):
120
+ get_device_topology_order_id(local_devices, global_devices)
tests/distributed/test_tpu_connector.py ADDED
@@ -0,0 +1,478 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import unittest
16
+ from unittest.mock import MagicMock, patch
17
+
18
+ from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorRole
19
+ from vllm.v1.request import RequestStatus
20
+
21
+ from tpu_inference.distributed import tpu_connector
22
+
23
+
24
+ class MockVllmConfig:
25
+
26
+ def __init__(self):
27
+ self.kv_transfer_config = MagicMock()
28
+ self.kv_transfer_config.is_kv_producer = True
29
+ self.cache_config = MagicMock()
30
+ self.cache_config.block_size = 16
31
+ self.parallel_config = MagicMock()
32
+
33
+
34
+ @patch("tpu_inference.distributed.tpu_connector.TPUConnectorWorker")
35
+ @patch("tpu_inference.distributed.tpu_connector.TPUConnectorScheduler")
36
+ class TestTPUConnector(unittest.TestCase):
37
+
38
+ def setUp(self):
39
+ self.vllm_config = MockVllmConfig()
40
+
41
+ def test_init_scheduler_role(self, mock_scheduler_cls, mock_worker_cls):
42
+ """
43
+ Tests that TPUConnector initializes the scheduler connector for the
44
+ SCHEDULER role.
45
+ """
46
+ connector = tpu_connector.TPUConnector(self.vllm_config,
47
+ KVConnectorRole.SCHEDULER)
48
+ mock_scheduler_cls.assert_called_once_with(self.vllm_config)
49
+ mock_worker_cls.assert_not_called()
50
+ self.assertIsNotNone(connector.connector_scheduler)
51
+ self.assertIsNone(connector.connector_worker)
52
+
53
+ def test_init_worker_role(self, mock_scheduler_cls, mock_worker_cls):
54
+ """
55
+ Tests that TPUConnector initializes the worker connector for the WORKER
56
+ role.
57
+ """
58
+ connector = tpu_connector.TPUConnector(self.vllm_config,
59
+ KVConnectorRole.WORKER)
60
+ mock_worker_cls.assert_called_once_with(self.vllm_config)
61
+ mock_scheduler_cls.assert_not_called()
62
+ self.assertIsNone(connector.connector_scheduler)
63
+ self.assertIsNotNone(connector.connector_worker)
64
+
65
+ def test_scheduler_methods_are_called(self, mock_scheduler_cls,
66
+ mock_worker_cls):
67
+ """Tests that scheduler-side methods are correctly delegated."""
68
+ mock_scheduler_instance = mock_scheduler_cls.return_value
69
+ connector = tpu_connector.TPUConnector(self.vllm_config,
70
+ KVConnectorRole.SCHEDULER)
71
+
72
+ mock_request = MagicMock()
73
+ mock_blocks = MagicMock()
74
+ mock_scheduler_output = MagicMock()
75
+
76
+ connector.get_num_new_matched_tokens(mock_request, 16)
77
+ mock_scheduler_instance.get_num_new_matched_tokens.assert_called_once_with(
78
+ mock_request, 16)
79
+
80
+ connector.update_state_after_alloc(mock_request, mock_blocks, 16)
81
+ mock_scheduler_instance.update_state_after_alloc.assert_called_once_with(
82
+ mock_request, mock_blocks, 16)
83
+
84
+ connector.build_connector_meta(mock_scheduler_output)
85
+ mock_scheduler_instance.build_connector_meta.assert_called_once_with()
86
+
87
+ connector.request_finished(mock_request, [1, 2])
88
+ mock_scheduler_instance.request_finished.assert_called_once_with(
89
+ mock_request, [1, 2])
90
+
91
+ def test_worker_methods_are_called(self, mock_scheduler_cls,
92
+ mock_worker_cls):
93
+ """Tests that worker-side methods are correctly delegated."""
94
+ mock_worker_instance = mock_worker_cls.return_value
95
+ connector = tpu_connector.TPUConnector(self.vllm_config,
96
+ KVConnectorRole.WORKER)
97
+ connector._connector_metadata = tpu_connector.TPUConnectorMetadata(
98
+ ) # need to set this for start_load_kv
99
+
100
+ mock_runner = MagicMock()
101
+
102
+ connector.register_runner(mock_runner)
103
+ mock_worker_instance.register_runner.assert_called_once_with(
104
+ mock_runner)
105
+
106
+ connector.start_load_kv(None)
107
+ mock_worker_instance.process_send_load.assert_called_once_with(
108
+ connector._connector_metadata)
109
+
110
+ connector.get_finished(set())
111
+ mock_worker_instance.get_finished.assert_called_once_with()
112
+
113
+
114
+ class TestTPUConnectorScheduler(unittest.TestCase):
115
+
116
+ def setUp(self):
117
+ self.vllm_config = MockVllmConfig()
118
+ self.vllm_config.cache_config.block_size = 16
119
+ self.vllm_config.kv_transfer_config.is_kv_producer = False
120
+
121
+ with patch("tpu_inference.distributed.tpu_connector.get_kv_ips",
122
+ return_value="1.1.1.1"), patch(
123
+ "tpu_inference.distributed.tpu_connector.get_kv_ports",
124
+ return_value=12345):
125
+ self.scheduler = tpu_connector.TPUConnectorScheduler(
126
+ self.vllm_config)
127
+
128
+ def test_get_num_new_matched_tokens_producer(self):
129
+ """Tests that producer returns 0 tokens to load."""
130
+ self.scheduler.is_producer = True
131
+ mock_request = MagicMock()
132
+ num_tokens, is_async = self.scheduler.get_num_new_matched_tokens(
133
+ mock_request, 16)
134
+ self.assertEqual(num_tokens, 0)
135
+ self.assertFalse(is_async)
136
+
137
+ def test_get_num_new_matched_tokens_consumer_needs_loading(self):
138
+ """Tests consumer calculates correct number of tokens to load."""
139
+ mock_request = MagicMock()
140
+ mock_request.prompt_token_ids = [0] * 35 # 2 blocks worth, plus some
141
+ num_computed_tokens = 16 # 1 block
142
+ # rounded_down(35) = 32. 32 - 16 = 16.
143
+ expected_tokens = 16
144
+ num_tokens, is_async = self.scheduler.get_num_new_matched_tokens(
145
+ mock_request, num_computed_tokens)
146
+ self.assertEqual(num_tokens, expected_tokens)
147
+ self.assertTrue(is_async)
148
+
149
+ def test_get_num_new_matched_tokens_consumer_no_loading(self):
150
+ """Tests consumer returns 0 if prompt is fully cached."""
151
+ mock_request = MagicMock()
152
+ mock_request.prompt_token_ids = [0] * 31 # less than 2 blocks
153
+ num_computed_tokens = 32 # 2 blocks computed
154
+ expected_tokens = 0
155
+ num_tokens, is_async = self.scheduler.get_num_new_matched_tokens(
156
+ mock_request, num_computed_tokens)
157
+ self.assertEqual(num_tokens, expected_tokens)
158
+ self.assertFalse(is_async)
159
+
160
+ def test_update_state_after_alloc_producer(self):
161
+ """Tests that update_state_after_alloc is a no-op for producers."""
162
+ self.scheduler.is_producer = True
163
+ self.scheduler.update_state_after_alloc(MagicMock(), MagicMock(), 16)
164
+ self.assertEqual(len(self.scheduler.reqs_to_load), 0)
165
+
166
+ def test_update_state_after_alloc_consumer_with_external_tokens(self):
167
+ """
168
+ Tests consumer state is updated when external tokens are needed.
169
+ """
170
+ mock_request = MagicMock()
171
+ mock_request.request_id = "req1"
172
+ mock_request.kv_transfer_params = {
173
+ "uuid": 123,
174
+ "remote_block_ids": [10, 11],
175
+ "remote_host": "2.2.2.2",
176
+ "remote_port": 54321
177
+ }
178
+ mock_blocks = MagicMock()
179
+ mock_blocks.get_block_ids.return_value = [[1, 2]]
180
+ num_external_tokens = 32
181
+
182
+ self.scheduler.update_state_after_alloc(mock_request, mock_blocks,
183
+ num_external_tokens)
184
+
185
+ self.assertIn("req1", self.scheduler.reqs_to_load)
186
+ load_meta = self.scheduler.reqs_to_load["req1"]
187
+ self.assertEqual(load_meta.uuid, 123)
188
+ self.assertEqual(load_meta.local_block_ids, [1, 2])
189
+ self.assertEqual(load_meta.remote_block_ids, [10, 11])
190
+
191
+ def test_update_state_after_alloc_consumer_no_external_tokens(self):
192
+ """
193
+ Tests consumer state is updated for notification when no external
194
+ tokens are needed.
195
+ """
196
+ mock_request = MagicMock()
197
+ mock_request.request_id = "req1"
198
+ mock_request.kv_transfer_params = {
199
+ "uuid": 123,
200
+ "remote_block_ids": [10, 11],
201
+ "remote_host": "2.2.2.2",
202
+ "remote_port": 54321
203
+ }
204
+ mock_blocks = MagicMock()
205
+ num_external_tokens = 0
206
+
207
+ self.scheduler.update_state_after_alloc(mock_request, mock_blocks,
208
+ num_external_tokens)
209
+
210
+ self.assertIn("req1", self.scheduler.reqs_to_load)
211
+ load_meta = self.scheduler.reqs_to_load["req1"]
212
+ self.assertEqual(load_meta.uuid, 123)
213
+ self.assertIsNone(load_meta.local_block_ids)
214
+ self.assertIsNone(load_meta.remote_block_ids)
215
+
216
+ def test_build_connector_meta(self):
217
+ """Tests that metadata is built and state is cleared."""
218
+ self.scheduler.is_producer = True
219
+ self.scheduler.reqs_to_send = {"req1": "meta1"}
220
+ meta = self.scheduler.build_connector_meta()
221
+ self.assertEqual(meta.reqs_to_send, {"req1": "meta1"})
222
+ self.assertEqual(len(self.scheduler.reqs_to_send),
223
+ 0) # check it was cleared
224
+
225
+ self.scheduler.is_producer = False
226
+ self.scheduler.reqs_to_load = {"req2": "meta2"}
227
+ meta = self.scheduler.build_connector_meta()
228
+ self.assertEqual(meta.reqs_to_load, {"req2": "meta2"})
229
+ self.assertEqual(len(self.scheduler.reqs_to_load), 0)
230
+
231
+ def test_request_finished_consumer(self):
232
+ """Tests request_finished is a no-op for consumers."""
233
+ self.scheduler.is_producer = False
234
+ delay_free, params = self.scheduler.request_finished(MagicMock(), [])
235
+ self.assertFalse(delay_free)
236
+ self.assertIsNone(params)
237
+
238
@patch("tpu_inference.distributed.tpu_connector.get_uuid",
       return_value=456)
def test_request_finished_producer_finished_by_length(self, mock_get_uuid):
    """Check the producer path for a request that ended by hitting its length cap.

    ``get_uuid`` is patched to return 456 so the uuid recorded in the send
    metadata and in the returned parameters is predictable.
    """
    self.scheduler.is_producer = True

    finished_blocks = [1, 2]
    request = MagicMock()
    request.request_id = "req-finished"
    request.status = RequestStatus.FINISHED_LENGTH_CAPPED
    # Presumably two full blocks for the block size configured in the
    # fixture (the fixture is not visible from this test).
    request.num_computed_tokens = 32

    delay_free, params = self.scheduler.request_finished(
        request, finished_blocks)

    # The blocks must be kept alive and a send entry must be queued.
    self.assertTrue(delay_free)
    self.assertIn("req-finished", self.scheduler.reqs_to_send)
    queued = self.scheduler.reqs_to_send["req-finished"]
    self.assertEqual(queued.uuid, 456)
    self.assertEqual(queued.local_block_ids, [1, 2])

    # The returned parameters describe where the KV data can be fetched.
    # The host/port values are presumably supplied by the test fixture.
    self.assertIsNotNone(params)
    expected_params = {
        "uuid": 456,
        "remote_block_ids": [1, 2],
        "remote_host": "1.1.1.1",
        "remote_port": 12345,
    }
    for key, expected_value in expected_params.items():
        self.assertEqual(params[key], expected_value)
263
+
264
def test_request_finished_producer_not_finished(self):
    """Check the producer path for a request whose status is still RUNNING.

    Nothing should be delayed and no transfer parameters should be
    returned for such a request.
    """
    self.scheduler.is_producer = True

    running_request = MagicMock()
    running_request.status = RequestStatus.RUNNING

    outcome = self.scheduler.request_finished(running_request, [1, 2])

    # The scheduler returns a (delay_free, params) pair.
    delay_free, params = outcome
    self.assertFalse(delay_free)
    self.assertIsNone(params)
273
+
274
def test_request_finished_producer_prompt_too_short(self):
    """Check the producer path when too few tokens were computed to transfer.

    The request finished by length cap, but only 10 tokens were computed
    (presumably less than one block for the fixture's block size), so no
    send entry is expected and the returned parameters must be empty.
    """
    self.scheduler.is_producer = True

    single_block = [1]
    short_request = MagicMock()
    short_request.request_id = "req-short"
    short_request.status = RequestStatus.FINISHED_LENGTH_CAPPED
    short_request.num_computed_tokens = 10

    delay_free, params = self.scheduler.request_finished(
        short_request, single_block)

    self.assertFalse(delay_free)
    self.assertEqual(params, {})
    self.assertNotIn("req-short", self.scheduler.reqs_to_send)
289
+
290
+
291
class TestTPUConnectorWorker(unittest.TestCase):
    """Unit tests for ``tpu_connector.TPUConnectorWorker``.

    ``setUp`` replaces a fixed set of names inside the
    ``tpu_inference.distributed.tpu_connector`` module with mocks. The
    started mocks are available through ``self.all_mocks``, keyed by the
    name of the patched attribute.
    """

    def setUp(self):
        """Create the mock config and patch the connector module's collaborators."""
        self.vllm_config = MockVllmConfig()

        module_path = 'tpu_inference.distributed.tpu_connector'
        # Attribute name in the connector module -> extra kwargs for patch().
        patch_specs = {
            "jax": {},
            "get_host_ip": {"return_value": '127.0.0.1'},
            "get_kv_transfer_port": {"return_value": 10000},
            "get_side_channel_port": {"return_value": 20000},
            "start_transfer_server": {},
            "zmq": {},
            "threading": {},
            "ThreadPoolExecutor": {},
            "device_array": {},
            "select_from_kv_caches": {},
            "scatter_kv_slices": {},
            "time": {},
            "make_zmq_path": {},
            "make_zmq_socket": {},
        }

        self.all_mocks = {}
        for name, patch_kwargs in patch_specs.items():
            patcher = patch(f'{module_path}.{name}', **patch_kwargs)
            self.all_mocks[name] = patcher.start()
            # Register the cleanup right after each start so that a failure
            # while starting a later patcher cannot leave earlier patches
            # active for the rest of the test run.
            self.addCleanup(patcher.stop)

        self.all_mocks["jax"].local_devices.return_value = [MagicMock()]

    def _make_worker(self, *, is_producer):
        """Set the producer flag on the mock config and build a worker from it."""
        self.vllm_config.kv_transfer_config.is_kv_producer = is_producer
        return tpu_connector.TPUConnectorWorker(self.vllm_config)

    def test_init_producer(self):
        """Tests worker initialization for the producer role."""
        worker = self._make_worker(is_producer=True)

        zmq_mock = self.all_mocks["zmq"]
        threading_mock = self.all_mocks["threading"]

        zmq_mock.Context.assert_called_once()
        threading_mock.Thread.assert_called_once()
        threading_mock.Event.assert_called()
        # The thread pool is only expected on the consumer side.
        self.all_mocks["ThreadPoolExecutor"].assert_not_called()
        self.assertTrue(worker.is_producer)

    def test_init_consumer(self):
        """Tests worker initialization for the consumer role."""
        worker = self._make_worker(is_producer=False)

        self.all_mocks["zmq"].Context.assert_called_once()
        self.all_mocks["threading"].Thread.assert_not_called()
        self.all_mocks["ThreadPoolExecutor"].assert_called_once_with(
            max_workers=64)
        self.assertFalse(worker.is_producer)

    def test_register_runner(self):
        """Tests that runner registration correctly sets worker attributes."""
        worker = self._make_worker(is_producer=False)

        kv_cache_layer = MagicMock()
        kv_cache_layer.shape = [10, 20, 30, 40]
        kv_cache_layer.dtype = 'float32'
        kv_cache_layer.sharding = 'sharding_spec'

        runner = MagicMock()
        # Five references to the same layer: the layer count is what matters.
        runner.kv_caches = [kv_cache_layer] * 5
        runner.mesh = 'mesh'

        worker.register_runner(runner)

        self.all_mocks["start_transfer_server"].assert_called_once()
        self.assertEqual(worker.runner, runner)
        self.assertEqual(worker.mesh, 'mesh')
        self.assertEqual(worker.num_layers, 5)
        self.assertEqual(worker.shape, [10, 20, 30, 40])
        self.assertEqual(worker.dtype, 'float32')
        self.assertEqual(worker.sharding, 'sharding_spec')

    def test_process_send_load_for_producer(self):
        """Tests process_send_load for the producer role."""
        worker = self._make_worker(is_producer=True)
        worker._prepare_kv_and_wait = MagicMock()

        send_meta = tpu_connector.SendMeta(uuid=1,
                                           local_block_ids=[1],
                                           expiration_time=123)
        meta = tpu_connector.TPUConnectorMetadata()
        meta.reqs_to_send = {"req1": send_meta}

        worker.process_send_load(meta)

        worker._prepare_kv_and_wait.assert_called_once_with("req1", send_meta)

    def test_process_send_load_for_consumer_loading(self):
        """Tests process_send_load for a consumer that needs to load KV."""
        worker = self._make_worker(is_producer=False)
        worker._maybe_build_kv_connection = MagicMock(return_value="conn")

        load_meta = tpu_connector.LoadMeta(uuid=1,
                                           local_block_ids=[1],
                                           remote_block_ids=[10],
                                           remote_host="host",
                                           remote_port=123)
        meta = tpu_connector.TPUConnectorMetadata()
        meta.reqs_to_load = {"req1": load_meta}

        worker.process_send_load(meta)

        worker._maybe_build_kv_connection.assert_called_once_with(load_meta)
        # The pull must be handed to the (mocked) thread pool instance.
        executor = self.all_mocks["ThreadPoolExecutor"].return_value
        executor.submit.assert_called_once_with(worker._pull_kv, "conn",
                                                load_meta)

    def test_process_send_load_for_consumer_notifying(self):
        """Tests process_send_load for a consumer that needs to notify."""
        worker = self._make_worker(is_producer=False)
        worker._maybe_build_notif_socket = MagicMock(return_value="socket")
        worker._notify_pull_done = MagicMock()

        # No block ids on either side: this entry takes the notify path.
        load_meta = tpu_connector.LoadMeta(uuid=1,
                                           local_block_ids=None,
                                           remote_block_ids=None,
                                           remote_host="host",
                                           remote_port=123)
        meta = tpu_connector.TPUConnectorMetadata()
        meta.reqs_to_load = {"req1": load_meta}

        worker.process_send_load(meta)

        worker._maybe_build_notif_socket.assert_called_once_with(load_meta)
        worker._notify_pull_done.assert_called_once_with("socket", "req1")

    def test_get_finished_recving(self):
        """Tests get_finished for a request that has finished pulling."""
        worker = self._make_worker(is_producer=False)
        worker.runner = MagicMock()
        # Capture the caches before the call so the assertion below compares
        # against what the worker saw when it scattered the pulled data.
        original_kv_caches = worker.runner.kv_caches

        finished_future = MagicMock()
        finished_future.done.return_value = True
        finished_future.result.return_value = ('kv_data', 'indices')
        worker.reqs_pulling = {'req1': finished_future}

        done_sending, done_recving = worker.get_finished()

        self.assertEqual(done_sending, set())
        self.assertEqual(done_recving, {'req1'})
        self.assertNotIn('req1', worker.reqs_pulling)
        self.all_mocks['scatter_kv_slices'].assert_called_once_with(
            original_kv_caches, 'kv_data', 'indices')

    def test_get_finished_sending_expired(self):
        """Tests get_finished for a request that has expired."""
        worker = self._make_worker(is_producer=True)

        # Mocked clock reads 1000 while the entry's second element is 900,
        # so the entry is treated as expired.
        self.all_mocks['time'].perf_counter.return_value = 1000
        worker.reqs_wait_pull = {'req1': ['kv_data', 900]}

        done_sending, done_recving = worker.get_finished()

        self.assertEqual(done_sending, {'req1'})
        self.assertEqual(done_recving, set())
        self.assertNotIn('req1', worker.reqs_wait_pull)
475
+
476
+
477
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
tests/e2e/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.