tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
@@ -0,0 +1,1033 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from contextlib import nullcontext
16
+ from unittest.mock import MagicMock, patch
17
+
18
+ import numpy as np
19
+ import pytest
20
+
21
+ from tpu_inference.runner.tpu_runner import TPUModelRunner
22
+
23
+
24
+ class TestTPUJaxRunnerDPInputsLightweight:
25
+
26
+ def setup_method(self):
27
+ self.runner = MagicMock()
28
+
29
+ # Basic DP configuration
30
+ self.runner.dp_size = 2
31
+ self.runner.max_num_tokens = 64
32
+ self.runner.max_num_reqs = 8
33
+ self.runner.max_num_blocks_per_req = 8
34
+ self.runner.num_tokens_paddings = [16, 32, 64]
35
+
36
+ # Mock input batch - adjust num_reqs to match test data
37
+ self.runner.input_batch = MagicMock()
38
+ self.runner.input_batch.num_reqs = 2
39
+ self.runner.input_batch.req_ids = ["req1", "req2", "req3", "req4"]
40
+ self.runner.input_batch.req_id_to_index = {
41
+ "req1": 0,
42
+ "req2": 1,
43
+ "req3": 2,
44
+ "req4": 3
45
+ }
46
+ self.runner.input_batch.num_computed_tokens_cpu = np.array(
47
+ [10, 20, 5, 15])
48
+ self.runner.input_batch.token_ids_cpu = np.random.randint(
49
+ 0, 1000, (8, 64), dtype=np.int32)
50
+
51
+ # Mock block table
52
+ mock_block_table = MagicMock()
53
+ mock_block_table.get_cpu_tensor.return_value = np.arange(32).reshape(
54
+ 4, 8)
55
+ self.runner.input_batch.block_table = [mock_block_table]
56
+
57
+ # Initialize CPU arrays that the method modifies
58
+ self.runner.input_ids_cpu = np.zeros(64, dtype=np.int32)
59
+ self.runner.positions_cpu = np.zeros(64, dtype=np.int32)
60
+ self.runner.query_start_loc_cpu = np.zeros(10, dtype=np.int32)
61
+ self.runner.seq_lens_cpu = np.zeros(8, dtype=np.int32)
62
+ self.runner.logits_indices_cpu = np.zeros(8, dtype=np.int32)
63
+ self.runner.block_tables_cpu = [np.zeros((8, 8), dtype=np.int32)]
64
+ self.runner.arange_cpu = np.arange(64, dtype=np.int64)
65
+
66
+ # mock kv cache group
67
+ mock_kv_cache_config = MagicMock()
68
+ mock_kv_cache_group = MagicMock()
69
+ mock_kv_cache_config.kv_cache_groups = [mock_kv_cache_group]
70
+ self.runner.kv_cache_config = mock_kv_cache_config
71
+ self.runner.use_hybrid_kvcache = False
72
+
73
+ # Mock scheduler config for async scheduling
74
+ self.runner.scheduler_config = MagicMock()
75
+ self.runner.scheduler_config.async_scheduling = False # Default to False for most tests
76
+ self.runner._pre_async_results = None # Default to None for most tests
77
+
78
+ # Bind the actual methods to our mock
79
+ self.runner._prepare_inputs_dp = TPUModelRunner._prepare_inputs_dp.__get__(
80
+ self.runner)
81
+ self.runner._prepare_dp_input_metadata = TPUModelRunner._prepare_dp_input_metadata.__get__(
82
+ self.runner)
83
+ self.runner._prepare_async_token_substitution_indices_dp = TPUModelRunner._prepare_async_token_substitution_indices_dp.__get__(
84
+ self.runner)
85
+
86
+ def _create_mock_scheduler_output(self,
87
+ num_scheduled_tokens_dict,
88
+ assigned_dp_ranks,
89
+ scheduled_spec_decode_tokens=None):
90
+ """Create a minimal mock scheduler output."""
91
+ mock_output = MagicMock()
92
+ mock_output.num_scheduled_tokens = num_scheduled_tokens_dict
93
+ mock_output.assigned_dp_rank = assigned_dp_ranks
94
+ mock_output.total_num_scheduled_tokens = sum(
95
+ num_scheduled_tokens_dict.values())
96
+ mock_output.scheduled_spec_decode_tokens = scheduled_spec_decode_tokens or {}
97
+ mock_output.grammar_bitmask = None
98
+ return mock_output
99
+
100
+ @patch('tpu_inference.runner.tpu_runner.NamedSharding')
101
+ @patch('tpu_inference.runner.tpu_runner.runner_utils')
102
+ @patch('tpu_inference.runner.tpu_runner.device_array',
103
+ side_effect=lambda mesh, tensors, **kwargs: tensors)
104
+ @patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
105
+ def test_prepare_inputs_dp_basic_functionality(self,
106
+ mock_sampling_metadata,
107
+ mock_device_array,
108
+ mock_runner_utils,
109
+ mock_named_sharding):
110
+ """Test basic functionality of _prepare_inputs_dp."""
111
+ # Mock utility functions
112
+ mock_runner_utils.get_padded_token_len.return_value = 16
113
+ mock_sampling_metadata.from_input_batch.return_value = MagicMock()
114
+ mock_named_sharding.return_value = MagicMock()
115
+
116
+ # Create test data - only use req1 and req2 to match num_reqs=2
117
+ num_scheduled_tokens = {"req1": 5, "req2": 3}
118
+ assigned_dp_ranks = {"req1": 0, "req2": 1}
119
+ scheduler_output = self._create_mock_scheduler_output(
120
+ num_scheduled_tokens, assigned_dp_ranks)
121
+
122
+ # Execute the method
123
+ result = self.runner._prepare_inputs_dp(scheduler_output)
124
+
125
+ # Basic assertions
126
+ assert len(result) == 8
127
+ input_ids, positions, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
128
+
129
+ # Verify utility functions were called
130
+ mock_runner_utils.get_padded_token_len.assert_called()
131
+
132
+ def test_prepare_inputs_dp_error_conditions(self):
133
+ """Test error handling in DP input preparation."""
134
+ # Test with zero scheduled tokens - should fail assertion: total_num_scheduled_tokens > 0
135
+ scheduler_output = self._create_mock_scheduler_output({}, {})
136
+ scheduler_output.total_num_scheduled_tokens = 0
137
+
138
+ with pytest.raises(AssertionError):
139
+ self.runner._prepare_inputs_dp(scheduler_output)
140
+
141
+ # Test with zero requests - should fail assertion: num_reqs > 0
142
+ self.runner.input_batch.num_reqs = 0
143
+ scheduler_output = self._create_mock_scheduler_output({"req1": 5},
144
+ {"req1": 0})
145
+
146
+ with pytest.raises(AssertionError):
147
+ self.runner._prepare_inputs_dp(scheduler_output)
148
+
149
+ def test_prepare_dp_input_metadata(self):
150
+ num_scheduled_tokens = {"req1": 10, "req2": 5, "req3": 8, "req4": 3}
151
+ assigned_dp_ranks = {"req1": 0, "req2": 0, "req3": 1, "req4": 1}
152
+
153
+ self.runner.input_batch.num_reqs = 4
154
+ self.runner.input_batch.req_ids = ["req1", "req2", "req3", "req4"]
155
+ self.runner.max_num_reqs = 8
156
+
157
+ scheduler_output = self._create_mock_scheduler_output(
158
+ num_scheduled_tokens, assigned_dp_ranks)
159
+
160
+ with patch('tpu_inference.runner.tpu_runner.runner_utils'
161
+ ) as mock_runner_utils:
162
+ mock_runner_utils.get_padded_token_len.side_effect = lambda paddings_list, val: 16 if val <= 15 else 32 # Padded tokens per DP rank
163
+
164
+ result = self.runner._prepare_dp_input_metadata(scheduler_output)
165
+
166
+ (req_ids_dp, req_indices_dp, num_scheduled_tokens_per_dp_rank,
167
+ scheduled_tokens_per_dp_rank, num_req_per_dp_rank,
168
+ padded_num_scheduled_tokens_per_dp_rank, padded_num_reqs,
169
+ padded_total_num_scheduled_tokens, padded_num_reqs_per_dp_rank,
170
+ logits_indices_selector, max_num_reqs_per_dp_rank) = result
171
+
172
+ # 1. req_ids_dp: Dictionary mapping DP rank to request IDs
173
+ assert isinstance(req_ids_dp, dict)
174
+ assert req_ids_dp[0] == ["req1", "req2"]
175
+ assert req_ids_dp[1] == ["req3", "req4"]
176
+
177
+ # 2. req_indices_dp: Dictionary mapping DP rank to request indices
178
+ assert isinstance(req_indices_dp, dict)
179
+ assert req_indices_dp[0] == [0, 1] # indices of req1, req2
180
+ assert req_indices_dp[1] == [2, 3] # indices of req3, req4
181
+
182
+ # 3. num_scheduled_tokens_per_dp_rank: Total tokens per DP rank
183
+ assert isinstance(num_scheduled_tokens_per_dp_rank, dict)
184
+ assert num_scheduled_tokens_per_dp_rank[0] == 15 # 10 + 5
185
+ assert num_scheduled_tokens_per_dp_rank[1] == 11 # 8 + 3
186
+
187
+ # 4. scheduled_tokens_per_dp_rank: List of token counts per request per DP rank
188
+ assert isinstance(scheduled_tokens_per_dp_rank, dict)
189
+ assert scheduled_tokens_per_dp_rank[0] == [10,
190
+ 5] # req1=10, req2=5
191
+ assert scheduled_tokens_per_dp_rank[1] == [8, 3] # req3=8, req4=3
192
+
193
+ # 5. num_req_per_dp_rank: Number of requests per DP rank
194
+ assert isinstance(num_req_per_dp_rank, dict)
195
+ assert num_req_per_dp_rank[0] == 2
196
+ assert num_req_per_dp_rank[1] == 2
197
+
198
+ # 6. padded_num_scheduled_tokens_per_dp_rank: Padded token count per rank
199
+ assert padded_num_scheduled_tokens_per_dp_rank == 16
200
+
201
+ # 7. padded_num_reqs: Total padded requests across all ranks
202
+ assert padded_num_reqs == 32 # 2 DP ranks * 16 padded reqs per rank
203
+
204
+ # 8. padded_total_num_scheduled_tokens: Total padded tokens across all ranks
205
+ assert padded_total_num_scheduled_tokens == 32 # 2 DP ranks * 16 padded tokens per rank
206
+
207
+ # 9. padded_num_reqs_per_dp_rank: Padded requests per DP rank
208
+ assert padded_num_reqs_per_dp_rank == 16
209
+
210
+ # 10. logits_indices_selector: Array to map back to original request order
211
+ assert isinstance(logits_indices_selector, np.ndarray)
212
+ assert len(logits_indices_selector) == 4 # One for each request
213
+ # Should map distributed positions back to original order
214
+ expected_selector = np.array([0, 1, 16, 17])
215
+ np.testing.assert_array_equal(logits_indices_selector,
216
+ expected_selector)
217
+
218
+ # 11. max_num_reqs_per_dp_rank: Maximum requests per DP rank
219
+ assert max_num_reqs_per_dp_rank == 4 # max_num_reqs (8) // dp_size (2)
220
+
221
+ def test_prepare_dp_input_metadata_empty_rank(self):
222
+ """Test metadata preparation with one empty DP rank"""
223
+ # Create test data where all requests go to rank 0, leaving rank 1 empty
224
+ num_scheduled_tokens = {"req1": 10, "req2": 5}
225
+ assigned_dp_ranks = {"req1": 0, "req2": 0}
226
+
227
+ self.runner.input_batch.num_reqs = 2
228
+ self.runner.input_batch.req_ids = ["req1", "req2"]
229
+ self.runner.max_num_reqs = 8
230
+
231
+ scheduler_output = self._create_mock_scheduler_output(
232
+ num_scheduled_tokens, assigned_dp_ranks)
233
+
234
+ with patch('tpu_inference.runner.tpu_runner.runner_utils'
235
+ ) as mock_runner_utils:
236
+ mock_runner_utils.get_padded_token_len.side_effect = lambda paddings_list, val: 16 if val <= 15 else 32
237
+
238
+ result = self.runner._prepare_dp_input_metadata(scheduler_output)
239
+
240
+ (req_ids_dp, req_indices_dp, num_scheduled_tokens_per_dp_rank,
241
+ scheduled_tokens_per_dp_rank, num_req_per_dp_rank,
242
+ padded_num_scheduled_tokens_per_dp_rank, padded_num_reqs,
243
+ padded_total_num_scheduled_tokens, padded_num_reqs_per_dp_rank,
244
+ logits_indices_selector, max_num_reqs_per_dp_rank) = result
245
+
246
+ # 1. req_ids_dp
247
+ assert isinstance(req_ids_dp, dict)
248
+ assert req_ids_dp[0] == ["req1", "req2"]
249
+ assert req_ids_dp[1] == [] # Empty rank
250
+
251
+ # 2. req_indices_dp
252
+ assert isinstance(req_indices_dp, dict)
253
+ assert req_indices_dp[0] == [0, 1] # req1, req2 indices
254
+ assert req_indices_dp[1] == [] # Empty rank
255
+
256
+ # 3. num_scheduled_tokens_per_dp_rank
257
+ assert isinstance(num_scheduled_tokens_per_dp_rank, dict)
258
+ assert num_scheduled_tokens_per_dp_rank[0] == 15 # 10 + 5
259
+ assert num_scheduled_tokens_per_dp_rank[1] == 0 # Empty rank
260
+
261
+ # 4. scheduled_tokens_per_dp_rank
262
+ assert isinstance(scheduled_tokens_per_dp_rank, dict)
263
+ assert scheduled_tokens_per_dp_rank[0] == [10,
264
+ 5] # req1=10, req2=5
265
+ assert scheduled_tokens_per_dp_rank[1] == [] # Empty rank
266
+
267
+ # 5. num_req_per_dp_rank
268
+ assert isinstance(num_req_per_dp_rank, dict)
269
+ assert num_req_per_dp_rank[0] == 2 # Both requests on rank 0
270
+ assert num_req_per_dp_rank[1] == 0 # No requests on rank 1
271
+
272
+ # 6. padded_num_scheduled_tokens_per_dp_rank
273
+ assert padded_num_scheduled_tokens_per_dp_rank == 16
274
+
275
+ # 7. padded_num_reqs
276
+ assert padded_num_reqs == 32 # 2 DP ranks * 16 padded reqs per rank
277
+
278
+ # 8. padded_total_num_scheduled_tokens
279
+ assert padded_total_num_scheduled_tokens == 32 # 2 DP ranks * 16 padded tokens per rank
280
+
281
+ # 10. padded_num_reqs_per_dp_rank: Padded requests per DP rank
282
+ assert padded_num_reqs_per_dp_rank == 16
283
+
284
+ # 11. logits_indices_selector: Should preserve original order since no reordering needed
285
+ assert isinstance(logits_indices_selector, np.ndarray)
286
+ assert len(logits_indices_selector) == 2
287
+ # Both requests on DP rank 0, positions 0 and 1
288
+ expected_selector = np.array([0, 1])
289
+ np.testing.assert_array_equal(logits_indices_selector,
290
+ expected_selector)
291
+
292
+ # 12. max_num_reqs_per_dp_rank: Maximum requests per DP rank
293
+ assert max_num_reqs_per_dp_rank == 4 # max_num_reqs (8) // dp_size (2)
294
+
295
+ def test_prepare_dp_input_metadata_logits_indices_selector_ordering(self):
296
+ """Test logits_indices_selector with mixed DP rank assignment."""
297
+ # Create requests with mixed assignment to test reordering
298
+ num_scheduled_tokens = {"req1": 4, "req2": 6, "req3": 2}
299
+ assigned_dp_ranks = {
300
+ "req1": 1,
301
+ "req2": 0,
302
+ "req3": 1
303
+ } # req2 on rank 0, req1&req3 on rank 1
304
+
305
+ self.runner.input_batch.num_reqs = 3
306
+ self.runner.input_batch.req_ids = ["req1", "req2", "req3"]
307
+
308
+ scheduler_output = self._create_mock_scheduler_output(
309
+ num_scheduled_tokens, assigned_dp_ranks)
310
+
311
+ with patch('tpu_inference.runner.tpu_runner.runner_utils'
312
+ ) as mock_runner_utils:
313
+ mock_runner_utils.get_padded_token_len.side_effect = lambda paddings_list, val: 8 if val <= 6 else 16
314
+
315
+ result = self.runner._prepare_dp_input_metadata(scheduler_output)
316
+
317
+ (req_ids_dp, req_indices_dp, _, _, _, _, _, _, _,
318
+ logits_indices_selector, _) = result
319
+
320
+ # Verify request distribution
321
+ assert req_ids_dp[0] == ["req2"] # rank 0: req2 (index 1)
322
+ assert req_ids_dp[1] == [
323
+ "req1", "req3"
324
+ ] # rank 1: req1 (index 0), req3 (index 2)
325
+
326
+ assert req_indices_dp[0] == [1] # req2 has original index 1
327
+ assert req_indices_dp[1] == [
328
+ 0, 2
329
+ ] # req1 has index 0, req3 has index 2
330
+
331
+ # The logits_indices_selector should map the DP-distributed positions back to original order
332
+
333
+ assert isinstance(logits_indices_selector, np.ndarray)
334
+ assert len(logits_indices_selector) == 3
335
+
336
+ expected_positions = np.array([8, 0, 9])
337
+ np.testing.assert_array_equal(logits_indices_selector,
338
+ expected_positions)
339
+
340
+ @patch('tpu_inference.runner.tpu_runner.NamedSharding')
341
+ @patch('tpu_inference.runner.tpu_runner.runner_utils')
342
+ @patch('tpu_inference.runner.tpu_runner.device_array',
343
+ side_effect=lambda mesh, tensors, **kwargs: tensors)
344
+ @patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
345
+ def test_prepare_inputs_dp_verify_content_balanced(self,
346
+ mock_sampling_metadata,
347
+ mock_device_array,
348
+ mock_runner_utils,
349
+ mock_named_sharding):
350
+ """Test _prepare_inputs_dp with content verification for balanced distribution."""
351
+
352
+ # Setup mocking with specific behavior for tokens vs requests
353
+ def mock_get_padded_token_len(paddings_list, val):
354
+ # For tokens: 8 if val <= 3 else 16
355
+ # For requests: 4 if val <= 1 else 8
356
+ if val <= 1:
357
+ return 4 # For request padding
358
+ elif val <= 3:
359
+ return 8 # For token padding
360
+ else:
361
+ return 16
362
+
363
+ mock_runner_utils.get_padded_token_len.side_effect = mock_get_padded_token_len
364
+ mock_sampling_instance = MagicMock()
365
+ mock_sampling_metadata.from_input_batch.return_value = mock_sampling_instance
366
+ mock_named_sharding.return_value = MagicMock()
367
+
368
+ # Setup deterministic test data
369
+ num_scheduled_tokens = {"req1": 2, "req2": 3}
370
+ assigned_dp_ranks = {"req1": 0, "req2": 1}
371
+
372
+ self.runner.input_batch.num_reqs = 2
373
+ self.runner.input_batch.req_ids = ["req1", "req2"]
374
+ self.runner.input_batch.num_computed_tokens_cpu = np.array(
375
+ [5, 6]) # Starting positions
376
+
377
+ # Setup known token sequences for verification
378
+ self.runner.input_batch.token_ids_cpu = np.zeros((8, 64),
379
+ dtype=np.int32)
380
+ # req1: [1001, 1002, 1003, ...]
381
+ # req2: [2001, 2002, 2003, ...]
382
+ for i in range(2):
383
+ start_val = (i + 1) * 1000 + 1
384
+ for j in range(64):
385
+ self.runner.input_batch.token_ids_cpu[i, j] = start_val + j
386
+
387
+ scheduler_output = self._create_mock_scheduler_output(
388
+ num_scheduled_tokens, assigned_dp_ranks)
389
+
390
+ # Setup additional required attributes
391
+ self.runner.uses_mrope = False
392
+ self.runner.phase_based_profiler = None
393
+ self.runner.lora_config = None
394
+ self.runner.mesh = MagicMock()
395
+ self.runner.data_parallel_sharding = MagicMock()
396
+ self.runner.data_parallel_attn_sharding = MagicMock()
397
+ self.runner.mm_manager = MagicMock()
398
+ self.runner.speculative_decoding_manager = MagicMock()
399
+ self.runner.lora_utils = MagicMock()
400
+ # self.runner.mrope_positions_cpu = np.zeros((3, 64), dtype=np.int64)
401
+
402
+ # Execute the method
403
+ result = self.runner._prepare_inputs_dp(scheduler_output)
404
+ input_ids, positions, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
405
+ # 1. Verify input_ids content
406
+ expected_input_ids = np.zeros(16, dtype=np.int32)
407
+ expected_input_ids[:2] = [1006, 1007]
408
+ expected_input_ids[8:11] = [2007, 2008, 2009]
409
+ assert np.array_equal(input_ids, expected_input_ids)
410
+
411
+ # 2. Verify attention_metadata positions content
412
+ expected_positions = np.zeros(16, dtype=np.int32)
413
+ expected_positions[:2] = [5, 6] # req1 positions
414
+ expected_positions[8:11] = [6, 7, 8]
415
+ assert np.array_equal(attention_metadata.input_positions,
416
+ expected_positions)
417
+
418
+ # 3. Verify query_start_loc content
419
+ query_start_loc = attention_metadata.query_start_loc_cpu
420
+ max_num_reqs_per_dp = self.runner.max_num_reqs // 2
421
+ expected_query_start = np.zeros(self.runner.max_num_reqs + 2,
422
+ dtype=np.int32)
423
+ # DP rank 0: cumsum([2]) = [2] at positions [1:2] → [0, 2, 1, 1, 1]
424
+ expected_query_start[1] = 2 # req1 has 2 tokens
425
+ expected_query_start[2:max_num_reqs_per_dp + 1] = 1
426
+ # DP rank 1: cumsum([3]) = [3] at positions [6:7] → [0, 3, 1, 1, 1]
427
+ expected_query_start[max_num_reqs_per_dp + 2] = 3 # req2 has 3 tokens
428
+ expected_query_start[max_num_reqs_per_dp + 3:] = 1
429
+ assert np.array_equal(query_start_loc, expected_query_start)
430
+
431
+ # 4. Verify seq_lens content
432
+ seq_lens = attention_metadata.seq_lens_cpu
433
+ # Should be computed_tokens + scheduled_tokens for each request
434
+ # DP rank 0: req1 at position 0, DP rank 1: req2 at position 4
435
+ expected_seq_lens = np.array([7, 0, 0, 0, 9, 0, 0,
436
+ 0]) # req1: 5+2=7, req2: 6+3=9
437
+ assert np.array_equal(seq_lens, expected_seq_lens)
438
+
439
+ # 5. Verify request_distribution content
440
+ expected_distribution = np.array([[0, 0, 1], [0, 0, 1]]).flatten()
441
+ np.testing.assert_array_equal(attention_metadata.request_distribution,
442
+ expected_distribution)
443
+
444
+ # 6. Verify logits_indices content
445
+ assert len(logits_indices) == 8 # padded_num_reqs
446
+ expected_logits = np.full(8, -1, dtype=np.int32)
447
+ expected_logits[0] = 1 # req1 last token position (2-1)
448
+ expected_logits[
449
+ 4] = 2 # req2 last token position (3-1) at DP rank 1 offset (4*1)
450
+ assert np.array_equal(logits_indices, expected_logits)
451
+
452
+ # 7. Verify logits_indices_selector
453
+ assert len(logits_indices_selector) == 2
454
+ assert np.array_equal(logits_indices_selector, np.array([0, 4]))
455
+
456
    @patch('tpu_inference.runner.tpu_runner.NamedSharding')
    @patch('tpu_inference.runner.tpu_runner.runner_utils')
    @patch('tpu_inference.runner.tpu_runner.device_array',
           side_effect=lambda mesh, tensors, **kwargs: tensors)
    @patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
    def test_prepare_inputs_dp_verify_content_empty_rank(
            self, mock_sampling_metadata, mock_device_array, mock_runner_utils,
            mock_named_sharding):
        """Test _prepare_inputs_dp with detailed content verification for empty rank case.

        Both requests are assigned to DP rank 0, so rank 1's slots in every
        output array must stay at their zero / padding values.
        """

        # Setup mocking: padding helper returns fixed bucket sizes so the
        # expected array layouts below are deterministic.
        def mock_get_padded_token_len(paddings_list, val):
            if val <= 2:
                return 4  # For request padding (max 2 requests)
            elif val <= 5:
                return 8  # For token padding
            else:
                return 16

        mock_runner_utils.get_padded_token_len.side_effect = mock_get_padded_token_len
        mock_sampling_instance = MagicMock()
        mock_sampling_metadata.from_input_batch.return_value = mock_sampling_instance
        mock_named_sharding.return_value = MagicMock()

        # Setup test data with all requests on rank 0 (empty rank 1)
        num_scheduled_tokens = {"req1": 3, "req2": 2}
        assigned_dp_ranks = {
            "req1": 0,
            "req2": 0
        }  # Both on rank 0, rank 1 empty

        self.runner.input_batch.num_reqs = 2
        self.runner.input_batch.req_ids = ["req1", "req2"]
        self.runner.input_batch.num_computed_tokens_cpu = np.array(
            [4, 6])  # Starting positions

        # Setup deterministic token sequences for verification
        self.runner.input_batch.token_ids_cpu = np.zeros((8, 64),
                                                         dtype=np.int32)
        # req1: [5001, 5002, 5003, ...] starting at position 4
        # req2: [6001, 6002, 6003, ...] starting at position 6
        for i in range(2):
            start_val = (i + 5) * 1000 + 1  # 5001, 6001
            for j in range(64):
                self.runner.input_batch.token_ids_cpu[i, j] = start_val + j

        scheduler_output = self._create_mock_scheduler_output(
            num_scheduled_tokens, assigned_dp_ranks)

        # Setup required attributes
        self.runner.uses_mrope = False
        self.runner.phase_based_profiler = None
        self.runner.lora_config = None
        self.runner.mesh = MagicMock()
        self.runner.data_parallel_sharding = MagicMock()
        self.runner.data_parallel_attn_sharding = MagicMock()
        self.runner.mm_manager = MagicMock()
        self.runner.speculative_decoding_manager = MagicMock()
        self.runner.lora_utils = MagicMock()

        # Execute the method
        result = self.runner._prepare_inputs_dp(scheduler_output)
        input_ids, positions, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result

        # 1. Verify input_ids
        expected_input_ids = np.zeros(16, dtype=np.int32)
        # Rank 0
        expected_input_ids[:5] = [5005, 5006, 5007, 6007, 6008]
        # Rank 1 (positions 8-15) should remain zeros
        assert np.array_equal(input_ids, expected_input_ids)

        # 2. Verify attention_metadata
        expected_positions = np.zeros(16, dtype=np.int32)
        expected_positions[:3] = [4, 5, 6]  # req1 positions: 4 + [0, 1, 2]
        expected_positions[3:5] = [6, 7]  # req2 positions: 6 + [0, 1]
        # Rank 1 positions (8-15) remain zeros
        assert np.array_equal(attention_metadata.input_positions,
                              expected_positions)

        # 3. Verify query_start_loc
        query_start_loc = attention_metadata.query_start_loc_cpu
        max_num_reqs_per_dp = self.runner.max_num_reqs // 2  # 4
        expected_query_start = np.zeros(self.runner.max_num_reqs + 2,
                                        dtype=np.int32)
        # Rank 0: req1 (3 tokens), req2 (2 tokens)
        expected_query_start[1] = 3  # req1 has 3 tokens
        expected_query_start[2] = 5  # cumulative: 3 + 2 = 5
        expected_query_start[3:max_num_reqs_per_dp + 1] = 1  # padding
        # Rank 1: empty (all zeros)
        expected_query_start[max_num_reqs_per_dp +
                             1:] = 0  # Empty rank sets to 0
        assert np.array_equal(query_start_loc, expected_query_start)

        # 4. Verify seq_lens
        seq_lens = attention_metadata.seq_lens_cpu
        expected_seq_lens = np.zeros(8, dtype=np.int32)
        # Rank 0: req1 (4+3=7), req2 (6+2=8), then padding
        expected_seq_lens[
            0] = 7  # req1: computed_tokens(4) + scheduled_tokens(3)
        expected_seq_lens[
            1] = 8  # req2: computed_tokens(6) + scheduled_tokens(2)
        # Rank 1: all zeros
        assert np.array_equal(seq_lens, expected_seq_lens)

        # 5. Verify request_distribution
        expected_distribution = np.array([[0, 0, 2], [0, 0, 0]]).flatten()
        np.testing.assert_array_equal(attention_metadata.request_distribution,
                                      expected_distribution)

        # 6. Verify logits_indices
        assert len(
            logits_indices) == 8  # padded_num_reqs (8 in this case, not 16)
        # Rank 0: req1 ends at pos 2, req2 ends at pos 4
        # Rank 1: empty, so -1 padding
        expected_logits = np.full(8, -1, dtype=np.int32)
        expected_logits[0] = 2  # req1 ends at position 2 (3-1)
        expected_logits[1] = 4  # req2 ends at position 4 (5-1)
        assert np.array_equal(logits_indices, expected_logits)

        # 7. Verify logits_indices_selector
        assert len(logits_indices_selector) == 2
        expected_selector = np.array([0, 1])
        np.testing.assert_array_equal(logits_indices_selector,
                                      expected_selector)
580
+
581
    @patch('tpu_inference.runner.tpu_runner.NamedSharding')
    @patch('tpu_inference.runner.tpu_runner.runner_utils')
    @patch('tpu_inference.runner.tpu_runner.device_array',
           side_effect=lambda mesh, tensors, **kwargs: tensors)
    @patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
    def test_prepare_inputs_dp_with_decode_requests(self,
                                                    mock_sampling_metadata,
                                                    mock_device_array,
                                                    mock_runner_utils,
                                                    mock_named_sharding):
        """Test _prepare_inputs_dp with decode requests (1 token each) to verify request_distribution."""

        # Setup mocking: fixed padding buckets keep the layout deterministic.
        def mock_get_padded_token_len(paddings_list, val):
            if val <= 2:
                return 4  # For request padding
            elif val <= 4:
                return 8  # For token padding
            else:
                return 16

        mock_runner_utils.get_padded_token_len.side_effect = mock_get_padded_token_len
        mock_sampling_instance = MagicMock()
        mock_sampling_metadata.from_input_batch.return_value = mock_sampling_instance
        mock_named_sharding.return_value = MagicMock()

        # Setup test data with decode requests (1 token) and prefill requests (>1 token)
        # req1: decode (1 token), req2: decode (1 token), req3: prefill (3 tokens), req4: decode (1 token)
        num_scheduled_tokens = {"req1": 1, "req2": 1, "req3": 3, "req4": 1}
        assigned_dp_ranks = {"req1": 0, "req2": 0, "req3": 1, "req4": 1}

        self.runner.input_batch.num_reqs = 4
        self.runner.input_batch.req_ids = ["req1", "req2", "req3", "req4"]
        self.runner.input_batch.num_computed_tokens_cpu = np.array(
            [5, 6, 7, 8])
        self.runner.input_batch.token_ids_cpu = np.zeros((8, 64),
                                                         dtype=np.int32)

        scheduler_output = self._create_mock_scheduler_output(
            num_scheduled_tokens, assigned_dp_ranks)

        # Setup required attributes
        self.runner.uses_mrope = False
        self.runner.phase_based_profiler = None
        self.runner.lora_config = None
        self.runner.mesh = MagicMock()
        self.runner.data_parallel_sharding = MagicMock()
        self.runner.data_parallel_attn_sharding = MagicMock()
        self.runner.mm_manager = MagicMock()
        self.runner.speculative_decoding_manager = MagicMock()
        self.runner.lora_utils = MagicMock()

        # Execute the method
        result = self.runner._prepare_inputs_dp(scheduler_output)
        input_ids, positions, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result

        # Verify request_distribution
        # DP rank 0: req1 (decode), req2 (decode) -> [2, 2, 2]
        # DP rank 1: req3 (prefill), req4 (decode) -> [1, 1, 2]
        expected_distribution = np.array([[2, 2, 2], [1, 1, 2]]).flatten()
        np.testing.assert_array_equal(attention_metadata.request_distribution,
                                      expected_distribution)
643
+
644
    @patch('tpu_inference.runner.tpu_runner.NamedSharding')
    @patch('tpu_inference.runner.tpu_runner.runner_utils')
    @patch('tpu_inference.runner.tpu_runner.device_array',
           side_effect=lambda mesh, tensors, **kwargs: tensors)
    @patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
    def test_prepare_inputs_dp_all_decode_requests(self,
                                                   mock_sampling_metadata,
                                                   mock_device_array,
                                                   mock_runner_utils,
                                                   mock_named_sharding):
        """Test _prepare_inputs_dp with all decode requests."""

        # Setup mocking: fixed padding buckets keep the layout deterministic.
        def mock_get_padded_token_len(paddings_list, val):
            if val <= 2:
                return 4
            elif val <= 4:
                return 8
            else:
                return 16

        mock_runner_utils.get_padded_token_len.side_effect = mock_get_padded_token_len
        mock_sampling_instance = MagicMock()
        mock_sampling_metadata.from_input_batch.return_value = mock_sampling_instance
        mock_named_sharding.return_value = MagicMock()

        # All requests are decode (1 token each)
        num_scheduled_tokens = {"req1": 1, "req2": 1}
        assigned_dp_ranks = {"req1": 0, "req2": 1}

        self.runner.input_batch.num_reqs = 2
        self.runner.input_batch.req_ids = ["req1", "req2"]
        self.runner.input_batch.num_computed_tokens_cpu = np.array([5, 6])
        self.runner.input_batch.token_ids_cpu = np.zeros((8, 64),
                                                         dtype=np.int32)

        scheduler_output = self._create_mock_scheduler_output(
            num_scheduled_tokens, assigned_dp_ranks)

        # Setup required attributes
        self.runner.uses_mrope = False
        self.runner.phase_based_profiler = None
        self.runner.lora_config = None
        self.runner.mesh = MagicMock()
        self.runner.data_parallel_sharding = MagicMock()
        self.runner.data_parallel_attn_sharding = MagicMock()
        self.runner.mm_manager = MagicMock()
        self.runner.speculative_decoding_manager = MagicMock()
        self.runner.lora_utils = MagicMock()

        # Execute the method
        result = self.runner._prepare_inputs_dp(scheduler_output)
        input_ids, positions, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result

        # Verify request_distribution
        # Both ranks have only decode requests
        # DP rank 0: req1 (decode) -> [1, 1, 1]
        # DP rank 1: req2 (decode) -> [1, 1, 1]
        expected_distribution = np.array([[1, 1, 1], [1, 1, 1]]).flatten()
        np.testing.assert_array_equal(attention_metadata.request_distribution,
                                      expected_distribution)
705
+
706
    @patch('tpu_inference.runner.tpu_runner.NamedSharding')
    @patch('tpu_inference.runner.tpu_runner.runner_utils')
    @patch('tpu_inference.runner.tpu_runner.device_array',
           side_effect=lambda mesh, tensors, **kwargs: tensors)
    @patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
    def test_prepare_async_token_substitution_indices_dp(
            self, mock_sampling_metadata, mock_device_array, mock_runner_utils,
            mock_named_sharding):
        """Verify per-rank placeholder indices computed by
        _prepare_async_token_substitution_indices_dp.

        Only requests present in placeholder_req_id_to_index contribute an
        index pair; each cur-input index points at the request's last
        scheduled token within its rank's padded token window.
        """

        # Setup test data
        req_ids_dp = {0: ["req1", "req2"], 1: ["req3"]}
        scheduled_tokens_per_dp_rank = {0: [3, 2], 1: [4]}
        padded_num_scheduled_tokens_per_dp_rank = 8
        dp_size = 2

        # Setup _pre_async_results with placeholder mapping
        self.runner._pre_async_results = MagicMock()
        self.runner._pre_async_results.placeholder_req_id_to_index = {
            "req1": 0,
            "req3": 2
        }  # req2 is not a placeholder

        # Call the method
        result = self.runner._prepare_async_token_substitution_indices_dp(
            req_ids_dp, scheduled_tokens_per_dp_rank,
            padded_num_scheduled_tokens_per_dp_rank, dp_size)

        token_in_tpu_cur_input_indices_dp, token_in_tpu_pre_next_tokens_indices_dp = result

        # Verify DP rank 0
        # req1: token_offset=0, acc_cur_len starts at 0, after 3 tokens: 3, so last token at 2
        # req2: not a placeholder, should be skipped
        assert token_in_tpu_cur_input_indices_dp[0] == [2]
        assert token_in_tpu_pre_next_tokens_indices_dp[0] == [0]

        # Verify DP rank 1
        # req3: token_offset=8, acc_cur_len starts at 8, after 4 tokens: 12, so last token at 11
        assert token_in_tpu_cur_input_indices_dp[1] == [11]
        assert token_in_tpu_pre_next_tokens_indices_dp[1] == [2]
745
+
746
    @patch('tpu_inference.runner.tpu_runner.NamedSharding')
    @patch('tpu_inference.runner.tpu_runner.runner_utils')
    @patch('tpu_inference.runner.tpu_runner.device_array',
           side_effect=lambda mesh, tensors, **kwargs: tensors)
    @patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
    def test_prepare_async_token_substitution_indices_dp_no_placeholders(
            self, mock_sampling_metadata, mock_device_array, mock_runner_utils,
            mock_named_sharding):
        """Test when no requests are placeholders."""

        req_ids_dp = {0: ["req1", "req2"], 1: ["req3"]}
        scheduled_tokens_per_dp_rank = {0: [3, 2], 1: [4]}
        padded_num_scheduled_tokens_per_dp_rank = 8
        dp_size = 2

        # No placeholders
        self.runner._pre_async_results = MagicMock()
        self.runner._pre_async_results.placeholder_req_id_to_index = {}

        result = self.runner._prepare_async_token_substitution_indices_dp(
            req_ids_dp, scheduled_tokens_per_dp_rank,
            padded_num_scheduled_tokens_per_dp_rank, dp_size)

        token_in_tpu_cur_input_indices_dp, token_in_tpu_pre_next_tokens_indices_dp = result

        # All lists should be empty since no placeholders
        assert token_in_tpu_cur_input_indices_dp[0] == []
        assert token_in_tpu_pre_next_tokens_indices_dp[0] == []
        assert token_in_tpu_cur_input_indices_dp[1] == []
        assert token_in_tpu_pre_next_tokens_indices_dp[1] == []
776
+
777
    def test_apply_async_token_substitution_empty_indices(self):
        """Test _apply_async_token_substitution with empty indices (line 1025).

        With no substitution indices the method must return input_ids
        unchanged (early-exit path).
        """

        # Bind the actual method (the fixture runner is otherwise mocked)
        self.runner._apply_async_token_substitution = TPUModelRunner._apply_async_token_substitution.__get__(
            self.runner)

        input_ids = np.array([1, 2, 3, 4, 5])
        token_in_tpu_cur_input_indices = np.array([])
        token_in_tpu_pre_next_tokens_indices = np.array([])

        # Setup _pre_async_results
        self.runner._pre_async_results = MagicMock()
        self.runner._pre_async_results.next_tokens = np.array([10, 20, 30])
        self.runner.mesh = MagicMock()

        result = self.runner._apply_async_token_substitution(
            input_ids, token_in_tpu_cur_input_indices,
            token_in_tpu_pre_next_tokens_indices)

        # Should return input_ids unchanged
        np.testing.assert_array_equal(result, input_ids)
799
+
800
    @patch('tpu_inference.runner.tpu_runner.device_array',
           side_effect=lambda mesh, tensors, **kwargs: tensors)
    def test_apply_async_token_substitution_with_padding(
            self, mock_device_array):
        """Test _apply_async_token_substitution with padding.

        Checks that the substitute fn receives index arrays padded to the
        input length and the true (unpadded) substitution count.
        """

        # Bind the actual method (the fixture runner is otherwise mocked)
        self.runner._apply_async_token_substitution = TPUModelRunner._apply_async_token_substitution.__get__(
            self.runner)

        input_ids = np.array([1, 2, 3, 4, 5, 6, 7, 8])
        # Substitute positions 2 and 5
        token_in_tpu_cur_input_indices = np.array([2, 5])
        token_in_tpu_pre_next_tokens_indices = np.array([0, 1])

        # Setup _pre_async_results
        self.runner._pre_async_results = MagicMock()
        self.runner._pre_async_results.next_tokens = np.array([100, 200, 300])
        self.runner.mesh = MagicMock()
        self.runner.maybe_forbid_compile = nullcontext()

        # Mock the substitute function to verify it's called correctly
        mock_substitute_fn = MagicMock(
            return_value=np.array([1, 2, 100, 4, 5, 200, 7, 8]))
        self.runner._substitute_placeholder_token_fn = mock_substitute_fn

        _ = self.runner._apply_async_token_substitution(
            input_ids, token_in_tpu_cur_input_indices,
            token_in_tpu_pre_next_tokens_indices)

        # Verify the substitute function was called
        mock_substitute_fn.assert_called_once()
        call_args = mock_substitute_fn.call_args[0]

        # Verify input_ids
        np.testing.assert_array_equal(call_args[0], input_ids)

        # Verify padded indices length matches input_ids length
        assert len(call_args[1]) == len(input_ids)
        assert len(call_args[2]) == len(input_ids)

        # Verify placeholder_num
        assert call_args[4] == 2  # Number of actual substitutions
843
+
844
    def test_prepare_inputs_routing_to_dp(self):
        """Test _prepare_inputs routes to _prepare_inputs_dp when dp_size > 1."""

        # Bind the actual _prepare_inputs method
        self.runner._prepare_inputs = TPUModelRunner._prepare_inputs.__get__(
            self.runner)

        self.runner.dp_size = 2
        # NOTE(review): the stub returns a 6-tuple, but sibling tests unpack
        # 8 values from the real _prepare_inputs_dp — confirm _prepare_inputs
        # forwards the tuple opaquely, otherwise align the arity.
        self.runner._prepare_inputs_dp = MagicMock(return_value=(None, None,
                                                                 None, None,
                                                                 None, None))

        scheduler_output = MagicMock()
        self.runner._prepare_inputs(scheduler_output)

        # Verify _prepare_inputs_dp was called
        self.runner._prepare_inputs_dp.assert_called_once_with(
            scheduler_output)
862
+
863
    def test_prepare_inputs_routing_to_non_dp(self):
        """Test _prepare_inputs routes to _prepare_inputs_non_dp when dp_size == 1."""

        # Bind the actual _prepare_inputs method
        self.runner._prepare_inputs = TPUModelRunner._prepare_inputs.__get__(
            self.runner)

        self.runner.dp_size = 1
        # Stub the non-DP path; only the dispatch decision is under test here.
        self.runner._prepare_inputs_non_dp = MagicMock(
            return_value=(None, None, None, None, None, None, None))

        scheduler_output = MagicMock()
        self.runner._prepare_inputs(scheduler_output)

        # Verify _prepare_inputs_non_dp was called
        self.runner._prepare_inputs_non_dp.assert_called_once_with(
            scheduler_output)
880
+
881
    @patch('tpu_inference.runner.tpu_runner.NamedSharding')
    @patch('tpu_inference.runner.tpu_runner.runner_utils')
    @patch('tpu_inference.runner.tpu_runner.device_array',
           side_effect=lambda mesh, tensors, **kwargs: tensors)
    @patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
    def test_prepare_inputs_dp_with_async_scheduling(self,
                                                     mock_sampling_metadata,
                                                     mock_device_array,
                                                     mock_runner_utils,
                                                     mock_named_sharding):
        """With async_scheduling enabled, _prepare_inputs_dp must invoke the
        per-rank placeholder index preparation."""

        # Setup mocking
        def mock_get_padded_token_len(paddings_list, val):
            if val <= 2:
                return 4
            elif val <= 5:
                return 8
            else:
                return 16

        mock_runner_utils.get_padded_token_len.side_effect = mock_get_padded_token_len
        mock_sampling_instance = MagicMock()
        mock_sampling_metadata.from_input_batch.return_value = mock_sampling_instance
        mock_named_sharding.return_value = MagicMock()

        # Setup test data
        num_scheduled_tokens = {"req1": 3, "req2": 2}
        assigned_dp_ranks = {"req1": 0, "req2": 1}

        self.runner.input_batch.num_reqs = 2
        self.runner.input_batch.req_ids = ["req1", "req2"]
        self.runner.input_batch.num_computed_tokens_cpu = np.array([4, 6])
        self.runner.input_batch.token_ids_cpu = np.zeros((8, 64),
                                                         dtype=np.int32)

        scheduler_output = self._create_mock_scheduler_output(
            num_scheduled_tokens, assigned_dp_ranks)

        # Enable async scheduling
        self.runner.scheduler_config.async_scheduling = True
        self.runner._pre_async_results = MagicMock()
        self.runner._pre_async_results.placeholder_req_id_to_index = {
            "req1": 0
        }
        self.runner._pre_async_results.next_tokens = np.array([100])

        # Setup required attributes
        self.runner.uses_mrope = False
        self.runner.phase_based_profiler = None
        self.runner.lora_config = None
        self.runner.mesh = MagicMock()
        self.runner.data_parallel_sharding = MagicMock()
        self.runner.data_parallel_attn_sharding = MagicMock()
        self.runner.mm_manager = MagicMock()
        self.runner.speculative_decoding_manager = MagicMock()
        self.runner.lora_utils = MagicMock()

        # Mock the token substitution preparation
        mock_prepare_async = MagicMock(return_value=({
            0: [2],
            1: []
        }, {
            0: [0],
            1: []
        }))
        self.runner._prepare_async_token_substitution_indices_dp = mock_prepare_async

        # Execute the method
        _ = self.runner._prepare_inputs_dp(scheduler_output)

        # Verify async token substitution was called
        mock_prepare_async.assert_called_once()
953
+
954
    @patch('tpu_inference.runner.tpu_runner.NamedSharding')
    @patch('tpu_inference.runner.tpu_runner.runner_utils')
    @patch('tpu_inference.runner.tpu_runner.device_array',
           side_effect=lambda mesh, tensors, **kwargs: tensors)
    @patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
    def test_prepare_inputs_dp_async_token_substitution_application(
            self, mock_sampling_metadata, mock_device_array, mock_runner_utils,
            mock_named_sharding):
        """Test async token substitution application in DP mode."""

        # Setup mocking
        def mock_get_padded_token_len(paddings_list, val):
            if val <= 2:
                return 4
            elif val <= 5:
                return 8
            else:
                return 16

        mock_runner_utils.get_padded_token_len.side_effect = mock_get_padded_token_len
        mock_sampling_instance = MagicMock()
        mock_sampling_metadata.from_input_batch.return_value = mock_sampling_instance
        mock_named_sharding.return_value = MagicMock()

        # Setup test data
        num_scheduled_tokens = {"req1": 3, "req2": 2}
        assigned_dp_ranks = {"req1": 0, "req2": 1}

        self.runner.input_batch.num_reqs = 2
        self.runner.input_batch.req_ids = ["req1", "req2"]
        self.runner.input_batch.num_computed_tokens_cpu = np.array([4, 6])
        self.runner.input_batch.token_ids_cpu = np.zeros((8, 64),
                                                         dtype=np.int32)

        scheduler_output = self._create_mock_scheduler_output(
            num_scheduled_tokens, assigned_dp_ranks)

        # Enable async scheduling with placeholders
        self.runner.scheduler_config.async_scheduling = True
        self.runner._pre_async_results = MagicMock()
        self.runner._pre_async_results.placeholder_req_id_to_index = {
            "req1": 0,
            "req2": 1
        }
        self.runner._pre_async_results.next_tokens = np.array([100, 200])

        # Setup required attributes
        self.runner.uses_mrope = False
        self.runner.phase_based_profiler = None
        self.runner.lora_config = None
        self.runner.mesh = MagicMock()
        self.runner.data_parallel_sharding = MagicMock()
        self.runner.data_parallel_attn_sharding = MagicMock()
        self.runner.mm_manager = MagicMock()
        self.runner.speculative_decoding_manager = MagicMock()
        self.runner.lora_utils = MagicMock()

        # Mock the async token substitution application
        mock_apply_async = MagicMock(
            return_value=np.array([1, 2, 100, 4, 5, 200, 7, 8]))
        self.runner._apply_async_token_substitution = mock_apply_async

        # Execute the method
        _ = self.runner._prepare_inputs_dp(scheduler_output)

        # Verify _apply_async_token_substitution was called
        mock_apply_async.assert_called_once()
        call_args = mock_apply_async.call_args[0]

        # Verify indices were concatenated from both DP ranks
        token_in_tpu_cur_input_indices = call_args[1]
        token_in_tpu_pre_next_tokens_indices = call_args[2]

        # Should have indices from both ranks
        assert len(token_in_tpu_cur_input_indices) == 2
        assert len(token_in_tpu_pre_next_tokens_indices) == 2
1030
+
1031
+
1032
# Allow running this test module directly (outside the pytest CLI).
if __name__ == "__main__":
    pytest.main([__file__])