tpu-inference 0.12.0.dev20251213__py3-none-any.whl → 0.13.2.dev20251230__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (248):
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +14 -0
  31. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  32. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_test.py +14 -0
  35. tests/layers/__init__.py +13 -0
  36. tests/layers/common/__init__.py +13 -0
  37. tests/layers/common/test_attention_interface.py +156 -0
  38. tests/layers/common/test_quantization.py +149 -0
  39. tests/layers/jax/__init__.py +13 -0
  40. tests/layers/jax/attention/__init__.py +13 -0
  41. tests/layers/jax/attention/test_common_attention.py +103 -0
  42. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  43. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  44. tests/layers/jax/moe/__init__.py +13 -0
  45. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  46. tests/layers/jax/sample/__init__.py +13 -0
  47. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  48. tests/layers/jax/sample/test_sampling.py +115 -0
  49. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  50. tests/layers/jax/test_layers.py +155 -0
  51. tests/{test_quantization.py → layers/jax/test_qwix.py} +180 -50
  52. tests/layers/jax/test_rope.py +93 -0
  53. tests/layers/jax/test_sharding.py +159 -0
  54. tests/layers/jax/test_transformer_block.py +152 -0
  55. tests/layers/vllm/__init__.py +13 -0
  56. tests/layers/vllm/test_attention.py +363 -0
  57. tests/layers/vllm/test_awq.py +406 -0
  58. tests/layers/vllm/test_compressed_tensors_moe.py +199 -0
  59. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +441 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +443 -0
  61. tests/layers/vllm/test_fp8.py +17 -0
  62. tests/layers/vllm/test_mxfp4.py +320 -0
  63. tests/layers/vllm/test_unquantized.py +662 -0
  64. tests/layers/vllm/utils.py +87 -0
  65. tests/lora/__init__.py +13 -0
  66. tests/lora/conftest.py +14 -0
  67. tests/lora/test_bgmv.py +14 -0
  68. tests/lora/test_layers.py +25 -8
  69. tests/lora/test_lora.py +15 -1
  70. tests/lora/test_lora_perf.py +14 -0
  71. tests/models/__init__.py +13 -0
  72. tests/models/common/__init__.py +13 -0
  73. tests/models/common/test_model_loader.py +455 -0
  74. tests/models/jax/__init__.py +13 -0
  75. tests/models/jax/test_deepseek_v3.py +401 -0
  76. tests/models/jax/test_llama3.py +184 -0
  77. tests/models/jax/test_llama4.py +298 -0
  78. tests/models/jax/test_llama_eagle3.py +197 -0
  79. tests/models/jax/test_llama_guard_4.py +242 -0
  80. tests/models/jax/test_qwen2.py +172 -0
  81. tests/models/jax/test_qwen2_5_vl.py +605 -0
  82. tests/models/jax/test_qwen3.py +169 -0
  83. tests/models/jax/test_weight_loading.py +180 -0
  84. tests/models/jax/utils/__init__.py +13 -0
  85. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  86. tests/platforms/__init__.py +13 -0
  87. tests/platforms/test_tpu_platform.py +54 -0
  88. tests/runner/__init__.py +13 -0
  89. tests/runner/test_block_table.py +395 -0
  90. tests/runner/test_input_batch.py +226 -0
  91. tests/runner/test_kv_cache.py +220 -0
  92. tests/runner/test_kv_cache_manager.py +498 -0
  93. tests/runner/test_multimodal_manager.py +429 -0
  94. tests/runner/test_persistent_batch_manager.py +84 -0
  95. tests/runner/test_speculative_decoding_manager.py +368 -0
  96. tests/runner/test_structured_decoding_manager.py +220 -0
  97. tests/runner/test_tpu_runner.py +261 -0
  98. tests/runner/test_tpu_runner_dp.py +1099 -0
  99. tests/runner/test_tpu_runner_mesh.py +200 -0
  100. tests/runner/test_utils.py +411 -0
  101. tests/spec_decode/__init__.py +13 -0
  102. tests/spec_decode/test_eagle3.py +311 -0
  103. tests/test_base.py +14 -0
  104. tests/test_tpu_info.py +14 -0
  105. tests/test_utils.py +1 -43
  106. tests/worker/__init__.py +13 -0
  107. tests/worker/tpu_worker_test.py +414 -0
  108. tpu_inference/__init__.py +14 -0
  109. tpu_inference/core/__init__.py +13 -0
  110. tpu_inference/core/sched/__init__.py +13 -0
  111. tpu_inference/core/sched/dp_scheduler.py +372 -56
  112. tpu_inference/distributed/__init__.py +13 -0
  113. tpu_inference/distributed/jax_parallel_state.py +14 -0
  114. tpu_inference/distributed/tpu_connector.py +14 -9
  115. tpu_inference/distributed/utils.py +56 -4
  116. tpu_inference/executors/__init__.py +13 -0
  117. tpu_inference/executors/ray_distributed_executor.py +20 -3
  118. tpu_inference/experimental/__init__.py +13 -0
  119. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  120. tpu_inference/kernels/__init__.py +13 -0
  121. tpu_inference/kernels/collectives/__init__.py +13 -0
  122. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  123. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  124. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  125. tpu_inference/kernels/fused_moe/v1/kernel.py +171 -163
  126. tpu_inference/kernels/megablox/__init__.py +13 -0
  127. tpu_inference/kernels/megablox/common.py +54 -0
  128. tpu_inference/kernels/megablox/gmm.py +646 -0
  129. tpu_inference/kernels/mla/__init__.py +13 -0
  130. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  131. tpu_inference/kernels/mla/v1/kernel.py +20 -26
  132. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  133. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  134. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  135. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  136. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +112 -69
  137. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +85 -65
  138. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  139. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +374 -194
  140. tpu_inference/kernels/ragged_paged_attention/v3/util.py +13 -0
  141. tpu_inference/layers/__init__.py +13 -0
  142. tpu_inference/layers/common/__init__.py +13 -0
  143. tpu_inference/layers/common/attention_interface.py +26 -19
  144. tpu_inference/layers/common/attention_metadata.py +14 -0
  145. tpu_inference/layers/common/fused_moe_gmm.py +506 -0
  146. tpu_inference/layers/common/quant_methods.py +15 -0
  147. tpu_inference/layers/common/quantization.py +282 -0
  148. tpu_inference/layers/common/sharding.py +22 -3
  149. tpu_inference/layers/common/utils.py +94 -0
  150. tpu_inference/layers/jax/__init__.py +13 -0
  151. tpu_inference/layers/jax/attention/__init__.py +13 -0
  152. tpu_inference/layers/jax/attention/attention.py +19 -6
  153. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +52 -27
  154. tpu_inference/layers/jax/attention/gpt_oss_attention.py +19 -6
  155. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  156. tpu_inference/layers/jax/base.py +14 -0
  157. tpu_inference/layers/jax/constants.py +13 -0
  158. tpu_inference/layers/jax/layers.py +14 -0
  159. tpu_inference/layers/jax/misc.py +14 -0
  160. tpu_inference/layers/jax/moe/__init__.py +13 -0
  161. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  162. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  163. tpu_inference/layers/jax/moe/moe.py +43 -3
  164. tpu_inference/layers/jax/pp_utils.py +53 -0
  165. tpu_inference/layers/jax/rope.py +14 -0
  166. tpu_inference/layers/jax/rope_interface.py +14 -0
  167. tpu_inference/layers/jax/sample/__init__.py +13 -0
  168. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  169. tpu_inference/layers/jax/sample/sampling.py +15 -1
  170. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  171. tpu_inference/layers/jax/transformer_block.py +14 -0
  172. tpu_inference/layers/vllm/__init__.py +13 -0
  173. tpu_inference/layers/vllm/attention.py +4 -4
  174. tpu_inference/layers/vllm/fused_moe.py +100 -455
  175. tpu_inference/layers/vllm/linear.py +64 -0
  176. tpu_inference/layers/vllm/process_weights/__init__.py +13 -0
  177. tpu_inference/layers/vllm/{sharding.py → process_weights/cleanup_sharding.py} +24 -15
  178. tpu_inference/layers/vllm/process_weights/fused_moe_weights.py +369 -0
  179. tpu_inference/layers/vllm/process_weights/linear_weights.py +174 -0
  180. tpu_inference/layers/vllm/quantization/__init__.py +19 -3
  181. tpu_inference/layers/vllm/quantization/awq.py +96 -82
  182. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  183. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +19 -5
  184. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +119 -132
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +111 -91
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +79 -43
  188. tpu_inference/layers/vllm/quantization/{common.py → configs.py} +38 -26
  189. tpu_inference/layers/vllm/quantization/fp8.py +119 -0
  190. tpu_inference/layers/vllm/quantization/mxfp4.py +133 -220
  191. tpu_inference/layers/vllm/quantization/unquantized.py +154 -253
  192. tpu_inference/lora/__init__.py +13 -0
  193. tpu_inference/lora/torch_lora_ops.py +8 -13
  194. tpu_inference/models/__init__.py +13 -0
  195. tpu_inference/models/common/__init__.py +13 -0
  196. tpu_inference/models/common/model_loader.py +37 -16
  197. tpu_inference/models/jax/__init__.py +13 -0
  198. tpu_inference/models/jax/deepseek_v3.py +113 -124
  199. tpu_inference/models/jax/gpt_oss.py +23 -7
  200. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  201. tpu_inference/models/jax/llama3.py +99 -36
  202. tpu_inference/models/jax/llama4.py +14 -0
  203. tpu_inference/models/jax/llama_eagle3.py +14 -0
  204. tpu_inference/models/jax/llama_guard_4.py +15 -1
  205. tpu_inference/models/jax/qwen2.py +17 -2
  206. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  207. tpu_inference/models/jax/qwen3.py +17 -2
  208. tpu_inference/models/jax/utils/__init__.py +13 -0
  209. tpu_inference/models/jax/utils/file_utils.py +14 -0
  210. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  211. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +85 -24
  213. tpu_inference/models/jax/utils/weight_utils.py +32 -1
  214. tpu_inference/models/vllm/__init__.py +13 -0
  215. tpu_inference/models/vllm/vllm_model_wrapper.py +22 -4
  216. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  217. tpu_inference/platforms/__init__.py +14 -0
  218. tpu_inference/platforms/tpu_platform.py +27 -29
  219. tpu_inference/runner/__init__.py +13 -0
  220. tpu_inference/runner/compilation_manager.py +69 -35
  221. tpu_inference/runner/kv_cache.py +14 -0
  222. tpu_inference/runner/kv_cache_manager.py +15 -2
  223. tpu_inference/runner/lora_utils.py +16 -1
  224. tpu_inference/runner/multimodal_manager.py +16 -2
  225. tpu_inference/runner/persistent_batch_manager.py +14 -0
  226. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  227. tpu_inference/runner/structured_decoding_manager.py +14 -0
  228. tpu_inference/runner/tpu_runner.py +30 -10
  229. tpu_inference/spec_decode/__init__.py +13 -0
  230. tpu_inference/spec_decode/jax/__init__.py +13 -0
  231. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  232. tpu_inference/tpu_info.py +14 -0
  233. tpu_inference/utils.py +31 -30
  234. tpu_inference/worker/__init__.py +13 -0
  235. tpu_inference/worker/tpu_worker.py +23 -7
  236. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/METADATA +1 -1
  237. tpu_inference-0.13.2.dev20251230.dist-info/RECORD +266 -0
  238. tpu_inference/layers/vllm/linear_common.py +0 -208
  239. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  240. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  241. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  242. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  245. tpu_inference-0.12.0.dev20251213.dist-info/RECORD +0 -175
  246. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/WHEEL +0 -0
  247. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/licenses/LICENSE +0 -0
  248. {tpu_inference-0.12.0.dev20251213.dist-info → tpu_inference-0.13.2.dev20251230.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1099 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from contextlib import nullcontext
16
+ from unittest.mock import MagicMock, patch
17
+
18
+ import numpy as np
19
+ import pytest
20
+
21
+ from tpu_inference.runner.tpu_runner import TPUModelRunner
22
+
23
+
24
+ class TestTPUJaxRunnerDPInputsLightweight:
25
+
26
+ def setup_method(self):
27
+ self.runner = MagicMock()
28
+
29
+ # Basic DP configuration
30
+ self.runner.dp_size = 2
31
+ self.runner.max_num_tokens = 64
32
+ self.runner.max_num_reqs = 8
33
+ self.runner.max_num_blocks_per_req = 8
34
+ self.runner.num_tokens_paddings = [16, 32, 64]
35
+
36
+ # Mock input batch - adjust num_reqs to match test data
37
+ self.runner.input_batch = MagicMock()
38
+ self.runner.input_batch.num_reqs = 2
39
+ self.runner.input_batch.req_ids = ["req1", "req2", "req3", "req4"]
40
+ self.runner.input_batch.req_id_to_index = {
41
+ "req1": 0,
42
+ "req2": 1,
43
+ "req3": 2,
44
+ "req4": 3
45
+ }
46
+ self.runner.input_batch.num_computed_tokens_cpu = np.array(
47
+ [10, 20, 5, 15])
48
+ self.runner.input_batch.token_ids_cpu = np.random.randint(
49
+ 0, 1000, (8, 64), dtype=np.int32)
50
+
51
+ # Mock block table
52
+ mock_block_table = MagicMock()
53
+ mock_block_table.get_cpu_tensor.return_value = np.arange(32).reshape(
54
+ 4, 8)
55
+ self.runner.input_batch.block_table = [mock_block_table]
56
+
57
+ # Initialize CPU arrays that the method modifies
58
+ self.runner.input_ids_cpu = np.zeros(64, dtype=np.int32)
59
+ self.runner.positions_cpu = np.zeros(64, dtype=np.int32)
60
+ self.runner.query_start_loc_cpu = np.zeros(10, dtype=np.int32)
61
+ self.runner.seq_lens_cpu = np.zeros(8, dtype=np.int32)
62
+ self.runner.logits_indices_cpu = np.zeros(8, dtype=np.int32)
63
+ self.runner.block_tables_cpu = [np.zeros((8, 8), dtype=np.int32)]
64
+ self.runner.arange_cpu = np.arange(64, dtype=np.int64)
65
+
66
+ # mock kv cache group
67
+ mock_kv_cache_config = MagicMock()
68
+ mock_kv_cache_group = MagicMock()
69
+ mock_kv_cache_config.kv_cache_groups = [mock_kv_cache_group]
70
+ self.runner.kv_cache_config = mock_kv_cache_config
71
+ self.runner.use_hybrid_kvcache = False
72
+
73
+ # Mock scheduler config for async scheduling
74
+ self.runner.scheduler_config = MagicMock()
75
+ self.runner.scheduler_config.async_scheduling = False # Default to False for most tests
76
+ self.runner._pre_async_results = None # Default to None for most tests
77
+
78
+ # Bind the actual methods to our mock
79
+ self.runner._prepare_inputs_dp = TPUModelRunner._prepare_inputs_dp.__get__(
80
+ self.runner)
81
+ self.runner._prepare_dp_input_metadata = TPUModelRunner._prepare_dp_input_metadata.__get__(
82
+ self.runner)
83
+ self.runner._prepare_async_token_substitution_indices_dp = TPUModelRunner._prepare_async_token_substitution_indices_dp.__get__(
84
+ self.runner)
85
+
86
+ def _create_mock_scheduler_output(self,
87
+ num_scheduled_tokens_dict,
88
+ assigned_dp_ranks,
89
+ scheduled_spec_decode_tokens=None):
90
+ """Create a minimal mock scheduler output."""
91
+ mock_output = MagicMock()
92
+ mock_output.num_scheduled_tokens = num_scheduled_tokens_dict
93
+ mock_output.assigned_dp_rank = assigned_dp_ranks
94
+ mock_output.total_num_scheduled_tokens = sum(
95
+ num_scheduled_tokens_dict.values())
96
+ mock_output.scheduled_spec_decode_tokens = scheduled_spec_decode_tokens or {}
97
+ mock_output.grammar_bitmask = None
98
+ return mock_output
99
+
100
+ def _create_mock_hybrid_kv_cache_config(self):
101
+ mock_kv_cache_config = MagicMock()
102
+ mock_kv_cache_group1 = MagicMock()
103
+ mock_kv_cache_group1.layer_names = [f'layer.{i}' for i in range(10)]
104
+ mock_kv_cache_group2 = MagicMock()
105
+ mock_kv_cache_group2.layer_names = [
106
+ f'layer.{i}' for i in range(10, 20)
107
+ ]
108
+ mock_kv_cache_config.kv_cache_groups = [
109
+ mock_kv_cache_group1, mock_kv_cache_group2
110
+ ]
111
+ self.runner.kv_cache_config = mock_kv_cache_config
112
+ self.runner.use_hybrid_kvcache = True
113
+
114
+ @patch('tpu_inference.runner.tpu_runner.NamedSharding')
115
+ @patch('tpu_inference.runner.tpu_runner.runner_utils')
116
+ @patch('tpu_inference.runner.tpu_runner.device_array',
117
+ side_effect=lambda mesh, tensors, **kwargs: tensors)
118
+ @patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
119
+ def test_prepare_inputs_dp_basic_functionality(self,
120
+ mock_sampling_metadata,
121
+ mock_device_array,
122
+ mock_runner_utils,
123
+ mock_named_sharding):
124
+ """Test basic functionality of _prepare_inputs_dp."""
125
+ # Mock utility functions
126
+ mock_runner_utils.get_padded_token_len.return_value = 16
127
+ mock_sampling_metadata.from_input_batch.return_value = MagicMock()
128
+ mock_named_sharding.return_value = MagicMock()
129
+
130
+ # Create test data - only use req1 and req2 to match num_reqs=2
131
+ num_scheduled_tokens = {"req1": 5, "req2": 3}
132
+ assigned_dp_ranks = {"req1": 0, "req2": 1}
133
+ scheduler_output = self._create_mock_scheduler_output(
134
+ num_scheduled_tokens, assigned_dp_ranks)
135
+
136
+ # Execute the method
137
+ result = self.runner._prepare_inputs_dp(scheduler_output)
138
+
139
+ # Basic assertions
140
+ assert len(result) == 8
141
+ input_ids, positions, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
142
+
143
+ # Verify utility functions were called
144
+ mock_runner_utils.get_padded_token_len.assert_called()
145
+
146
+ def test_prepare_inputs_dp_error_conditions(self):
147
+ """Test error handling in DP input preparation."""
148
+ # Test with zero scheduled tokens - should fail assertion: total_num_scheduled_tokens > 0
149
+ scheduler_output = self._create_mock_scheduler_output({}, {})
150
+ scheduler_output.total_num_scheduled_tokens = 0
151
+
152
+ with pytest.raises(AssertionError):
153
+ self.runner._prepare_inputs_dp(scheduler_output)
154
+
155
+ # Test with zero requests - should fail assertion: num_reqs > 0
156
+ self.runner.input_batch.num_reqs = 0
157
+ scheduler_output = self._create_mock_scheduler_output({"req1": 5},
158
+ {"req1": 0})
159
+
160
+ with pytest.raises(AssertionError):
161
+ self.runner._prepare_inputs_dp(scheduler_output)
162
+
163
+ @patch('tpu_inference.runner.tpu_runner.NamedSharding')
164
+ @patch('tpu_inference.runner.tpu_runner.runner_utils')
165
+ @patch('tpu_inference.runner.tpu_runner.device_array',
166
+ side_effect=lambda mesh, tensors, **kwargs: tensors)
167
+ @patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
168
+ def test_prepare_inputs_dp_hybrid_kvcache(self, mock_sampling_metadata,
169
+ mock_device_array,
170
+ mock_runner_utils,
171
+ mock_named_sharding):
172
+ """Test basic functionality of _prepare_inputs_dp."""
173
+ # Mock utility functions
174
+ mock_runner_utils.get_padded_token_len.return_value = 16
175
+ mock_sampling_metadata.from_input_batch.return_value = MagicMock()
176
+ mock_named_sharding.return_value = MagicMock()
177
+
178
+ # Create test data - only use req1 and req2 to match num_reqs=2
179
+ num_scheduled_tokens = {"req1": 5, "req2": 3}
180
+ assigned_dp_ranks = {"req1": 0, "req2": 1}
181
+ scheduler_output = self._create_mock_scheduler_output(
182
+ num_scheduled_tokens, assigned_dp_ranks)
183
+
184
+ # Create hybrid kv cache config with 10 full attn layers, 10 sw attn layers
185
+ self._create_mock_hybrid_kv_cache_config()
186
+
187
+ # update input_batch's block_table
188
+ mock_block_table = MagicMock()
189
+ mock_block_table.get_cpu_tensor.return_value = np.arange(32).reshape(
190
+ 4, 8)
191
+ self.runner.input_batch.block_table = [
192
+ mock_block_table, mock_block_table
193
+ ]
194
+
195
+ # update model runner's block_tables_cpu:
196
+ self.runner.block_tables_cpu = [
197
+ np.zeros((8, 8), dtype=np.int32),
198
+ np.zeros((8, 8), dtype=np.int32)
199
+ ]
200
+
201
+ # Execute the method
202
+ result = self.runner._prepare_inputs_dp(scheduler_output)
203
+
204
+ # Basic assertions
205
+ assert len(result) == 8
206
+ input_ids, positions, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
207
+
208
+ # Verify utility functions were called
209
+ mock_runner_utils.get_padded_token_len.assert_called()
210
+
211
+ # Verify there's attention_metadata for each layer
212
+ assert isinstance(attention_metadata, dict)
213
+ assert len(attention_metadata) == 20
214
+
215
+ def test_prepare_dp_input_metadata(self):
216
+ num_scheduled_tokens = {"req1": 10, "req2": 5, "req3": 8, "req4": 3}
217
+ assigned_dp_ranks = {"req1": 0, "req2": 0, "req3": 1, "req4": 1}
218
+
219
+ self.runner.input_batch.num_reqs = 4
220
+ self.runner.input_batch.req_ids = ["req1", "req2", "req3", "req4"]
221
+ self.runner.max_num_reqs = 8
222
+
223
+ scheduler_output = self._create_mock_scheduler_output(
224
+ num_scheduled_tokens, assigned_dp_ranks)
225
+
226
+ with patch('tpu_inference.runner.tpu_runner.runner_utils'
227
+ ) as mock_runner_utils:
228
+ mock_runner_utils.get_padded_token_len.side_effect = lambda paddings_list, val: 16 if val <= 15 else 32 # Padded tokens per DP rank
229
+
230
+ result = self.runner._prepare_dp_input_metadata(scheduler_output)
231
+
232
+ (req_ids_dp, req_indices_dp, num_scheduled_tokens_per_dp_rank,
233
+ scheduled_tokens_per_dp_rank, num_req_per_dp_rank,
234
+ padded_num_scheduled_tokens_per_dp_rank, padded_num_reqs,
235
+ padded_total_num_scheduled_tokens, padded_num_reqs_per_dp_rank,
236
+ logits_indices_selector, max_num_reqs_per_dp_rank) = result
237
+
238
+ # 1. req_ids_dp: Dictionary mapping DP rank to request IDs
239
+ assert isinstance(req_ids_dp, dict)
240
+ assert req_ids_dp[0] == ["req1", "req2"]
241
+ assert req_ids_dp[1] == ["req3", "req4"]
242
+
243
+ # 2. req_indices_dp: Dictionary mapping DP rank to request indices
244
+ assert isinstance(req_indices_dp, dict)
245
+ assert req_indices_dp[0] == [0, 1] # indices of req1, req2
246
+ assert req_indices_dp[1] == [2, 3] # indices of req3, req4
247
+
248
+ # 3. num_scheduled_tokens_per_dp_rank: Total tokens per DP rank
249
+ assert isinstance(num_scheduled_tokens_per_dp_rank, dict)
250
+ assert num_scheduled_tokens_per_dp_rank[0] == 15 # 10 + 5
251
+ assert num_scheduled_tokens_per_dp_rank[1] == 11 # 8 + 3
252
+
253
+ # 4. scheduled_tokens_per_dp_rank: List of token counts per request per DP rank
254
+ assert isinstance(scheduled_tokens_per_dp_rank, dict)
255
+ assert scheduled_tokens_per_dp_rank[0] == [10,
256
+ 5] # req1=10, req2=5
257
+ assert scheduled_tokens_per_dp_rank[1] == [8, 3] # req3=8, req4=3
258
+
259
+ # 5. num_req_per_dp_rank: Number of requests per DP rank
260
+ assert isinstance(num_req_per_dp_rank, dict)
261
+ assert num_req_per_dp_rank[0] == 2
262
+ assert num_req_per_dp_rank[1] == 2
263
+
264
+ # 6. padded_num_scheduled_tokens_per_dp_rank: Padded token count per rank
265
+ assert padded_num_scheduled_tokens_per_dp_rank == 16
266
+
267
+ # 7. padded_num_reqs: Total padded requests across all ranks
268
+ assert padded_num_reqs == 32 # 2 DP ranks * 16 padded reqs per rank
269
+
270
+ # 8. padded_total_num_scheduled_tokens: Total padded tokens across all ranks
271
+ assert padded_total_num_scheduled_tokens == 32 # 2 DP ranks * 16 padded tokens per rank
272
+
273
+ # 9. padded_num_reqs_per_dp_rank: Padded requests per DP rank
274
+ assert padded_num_reqs_per_dp_rank == 16
275
+
276
+ # 10. logits_indices_selector: Array to map back to original request order
277
+ assert isinstance(logits_indices_selector, np.ndarray)
278
+ assert len(logits_indices_selector) == 4 # One for each request
279
+ # Should map distributed positions back to original order
280
+ expected_selector = np.array([0, 1, 16, 17])
281
+ np.testing.assert_array_equal(logits_indices_selector,
282
+ expected_selector)
283
+
284
+ # 11. max_num_reqs_per_dp_rank: Maximum requests per DP rank
285
+ assert max_num_reqs_per_dp_rank == 4 # max_num_reqs (8) // dp_size (2)
286
+
287
+ def test_prepare_dp_input_metadata_empty_rank(self):
288
+ """Test metadata preparation with one empty DP rank"""
289
+ # Create test data where all requests go to rank 0, leaving rank 1 empty
290
+ num_scheduled_tokens = {"req1": 10, "req2": 5}
291
+ assigned_dp_ranks = {"req1": 0, "req2": 0}
292
+
293
+ self.runner.input_batch.num_reqs = 2
294
+ self.runner.input_batch.req_ids = ["req1", "req2"]
295
+ self.runner.max_num_reqs = 8
296
+
297
+ scheduler_output = self._create_mock_scheduler_output(
298
+ num_scheduled_tokens, assigned_dp_ranks)
299
+
300
+ with patch('tpu_inference.runner.tpu_runner.runner_utils'
301
+ ) as mock_runner_utils:
302
+ mock_runner_utils.get_padded_token_len.side_effect = lambda paddings_list, val: 16 if val <= 15 else 32
303
+
304
+ result = self.runner._prepare_dp_input_metadata(scheduler_output)
305
+
306
+ (req_ids_dp, req_indices_dp, num_scheduled_tokens_per_dp_rank,
307
+ scheduled_tokens_per_dp_rank, num_req_per_dp_rank,
308
+ padded_num_scheduled_tokens_per_dp_rank, padded_num_reqs,
309
+ padded_total_num_scheduled_tokens, padded_num_reqs_per_dp_rank,
310
+ logits_indices_selector, max_num_reqs_per_dp_rank) = result
311
+
312
+ # 1. req_ids_dp
313
+ assert isinstance(req_ids_dp, dict)
314
+ assert req_ids_dp[0] == ["req1", "req2"]
315
+ assert req_ids_dp[1] == [] # Empty rank
316
+
317
+ # 2. req_indices_dp
318
+ assert isinstance(req_indices_dp, dict)
319
+ assert req_indices_dp[0] == [0, 1] # req1, req2 indices
320
+ assert req_indices_dp[1] == [] # Empty rank
321
+
322
+ # 3. num_scheduled_tokens_per_dp_rank
323
+ assert isinstance(num_scheduled_tokens_per_dp_rank, dict)
324
+ assert num_scheduled_tokens_per_dp_rank[0] == 15 # 10 + 5
325
+ assert num_scheduled_tokens_per_dp_rank[1] == 0 # Empty rank
326
+
327
+ # 4. scheduled_tokens_per_dp_rank
328
+ assert isinstance(scheduled_tokens_per_dp_rank, dict)
329
+ assert scheduled_tokens_per_dp_rank[0] == [10,
330
+ 5] # req1=10, req2=5
331
+ assert scheduled_tokens_per_dp_rank[1] == [] # Empty rank
332
+
333
+ # 5. num_req_per_dp_rank
334
+ assert isinstance(num_req_per_dp_rank, dict)
335
+ assert num_req_per_dp_rank[0] == 2 # Both requests on rank 0
336
+ assert num_req_per_dp_rank[1] == 0 # No requests on rank 1
337
+
338
+ # 6. padded_num_scheduled_tokens_per_dp_rank
339
+ assert padded_num_scheduled_tokens_per_dp_rank == 16
340
+
341
+ # 7. padded_num_reqs
342
+ assert padded_num_reqs == 32 # 2 DP ranks * 16 padded reqs per rank
343
+
344
+ # 8. padded_total_num_scheduled_tokens
345
+ assert padded_total_num_scheduled_tokens == 32 # 2 DP ranks * 16 padded tokens per rank
346
+
347
+ # 10. padded_num_reqs_per_dp_rank: Padded requests per DP rank
348
+ assert padded_num_reqs_per_dp_rank == 16
349
+
350
+ # 11. logits_indices_selector: Should preserve original order since no reordering needed
351
+ assert isinstance(logits_indices_selector, np.ndarray)
352
+ assert len(logits_indices_selector) == 2
353
+ # Both requests on DP rank 0, positions 0 and 1
354
+ expected_selector = np.array([0, 1])
355
+ np.testing.assert_array_equal(logits_indices_selector,
356
+ expected_selector)
357
+
358
+ # 12. max_num_reqs_per_dp_rank: Maximum requests per DP rank
359
+ assert max_num_reqs_per_dp_rank == 4 # max_num_reqs (8) // dp_size (2)
360
+
361
def test_prepare_dp_input_metadata_logits_indices_selector_ordering(self):
    """logits_indices_selector must restore original request order when
    requests are interleaved across DP ranks."""
    # req2 lands on rank 0 while req1/req3 land on rank 1, so the
    # DP-distributed order differs from the submission order.
    num_scheduled_tokens = {"req1": 4, "req2": 6, "req3": 2}
    assigned_dp_ranks = {"req1": 1, "req2": 0, "req3": 1}

    batch = self.runner.input_batch
    batch.num_reqs = 3
    batch.req_ids = ["req1", "req2", "req3"]

    scheduler_output = self._create_mock_scheduler_output(
        num_scheduled_tokens, assigned_dp_ranks)

    with patch('tpu_inference.runner.tpu_runner.runner_utils'
               ) as mock_runner_utils:
        # Token padding: 8 for small per-rank totals, 16 otherwise.
        mock_runner_utils.get_padded_token_len.side_effect = (
            lambda paddings_list, val: 8 if val <= 6 else 16)
        result = self.runner._prepare_dp_input_metadata(scheduler_output)

    (req_ids_dp, req_indices_dp, _, _, _, _, _, _, _,
     logits_indices_selector, _) = result

    # Requests grouped per rank, keeping their relative order.
    assert req_ids_dp[0] == ["req2"]
    assert req_ids_dp[1] == ["req1", "req3"]

    # Original batch indices preserved alongside the grouping.
    assert req_indices_dp[0] == [1]
    assert req_indices_dp[1] == [0, 2]

    # Selector maps DP-distributed logits slots back to submission order:
    # req1 -> rank-1 slot 8, req2 -> rank-0 slot 0, req3 -> rank-1 slot 9.
    assert isinstance(logits_indices_selector, np.ndarray)
    assert len(logits_indices_selector) == 3
    np.testing.assert_array_equal(logits_indices_selector,
                                  np.array([8, 0, 9]))
@patch('tpu_inference.runner.tpu_runner.NamedSharding')
@patch('tpu_inference.runner.tpu_runner.runner_utils')
@patch('tpu_inference.runner.tpu_runner.device_array',
       side_effect=lambda mesh, tensors, **kwargs: tensors)
@patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
def test_prepare_inputs_dp_verify_content_balanced(self,
                                                   mock_sampling_metadata,
                                                   mock_device_array,
                                                   mock_runner_utils,
                                                   mock_named_sharding):
    """Content-level verification of _prepare_inputs_dp for a balanced
    one-request-per-rank distribution."""

    def fake_padded_token_len(paddings_list, val):
        # Requests pad to 4, small token counts to 8, larger ones to 16.
        if val <= 1:
            return 4
        if val <= 3:
            return 8
        return 16

    mock_runner_utils.get_padded_token_len.side_effect = fake_padded_token_len
    mock_sampling_metadata.from_input_batch.return_value = MagicMock()
    mock_named_sharding.return_value = MagicMock()

    # req1 (2 tokens) -> rank 0, req2 (3 tokens) -> rank 1.
    batch = self.runner.input_batch
    batch.num_reqs = 2
    batch.req_ids = ["req1", "req2"]
    batch.num_computed_tokens_cpu = np.array([5, 6])  # starting positions

    # Deterministic token ids: row i holds (i+1)*1000 + 1 + column.
    batch.token_ids_cpu = np.zeros((8, 64), dtype=np.int32)
    for row in range(2):
        batch.token_ids_cpu[row, :] = (row + 1) * 1000 + 1 + np.arange(64)

    scheduler_output = self._create_mock_scheduler_output(
        {"req1": 2, "req2": 3}, {"req1": 0, "req2": 1})

    # Stub out collaborators the runner touches along the way.
    self.runner.uses_mrope = False
    self.runner.phase_based_profiler = None
    self.runner.lora_config = None
    for attr in ("mesh", "data_parallel_sharding",
                 "data_parallel_attn_sharding", "mm_manager",
                 "speculative_decoding_manager", "lora_utils"):
        setattr(self.runner, attr, MagicMock())

    (input_ids, positions, attention_metadata, sampling_metadata,
     logits_indices, spec_decode_metadata, logits_indices_selector,
     padded_num_reqs) = self.runner._prepare_inputs_dp(scheduler_output)

    # 1. input_ids: rank 0 tokens in [0:2], rank 1 tokens in [8:11].
    expected_input_ids = np.zeros(16, dtype=np.int32)
    expected_input_ids[:2] = [1006, 1007]
    expected_input_ids[8:11] = [2007, 2008, 2009]
    assert np.array_equal(input_ids, expected_input_ids)

    # 2. positions: computed-token offsets per request.
    expected_positions = np.zeros(16, dtype=np.int32)
    expected_positions[:2] = [5, 6]
    expected_positions[8:11] = [6, 7, 8]
    assert np.array_equal(attention_metadata.input_positions,
                          expected_positions)

    # 3. query_start_loc: per-rank cumulative token counts plus padding.
    max_num_reqs_per_dp = self.runner.max_num_reqs // 2
    expected_query_start = np.zeros(self.runner.max_num_reqs + 2,
                                    dtype=np.int32)
    expected_query_start[1] = 2  # req1 contributes 2 tokens
    expected_query_start[2:max_num_reqs_per_dp + 1] = 1
    expected_query_start[max_num_reqs_per_dp + 2] = 3  # req2: 3 tokens
    expected_query_start[max_num_reqs_per_dp + 3:] = 1
    assert np.array_equal(attention_metadata.query_start_loc_cpu,
                          expected_query_start)

    # 4. seq_lens: computed + scheduled tokens in each rank's first slot.
    expected_seq_lens = np.array([7, 0, 0, 0, 9, 0, 0, 0])  # 5+2, 6+3
    assert np.array_equal(attention_metadata.seq_lens_cpu,
                          expected_seq_lens)

    # 5. request_distribution: one prefill-style request per rank.
    expected_distribution = np.array([[0, 0, 1], [0, 0, 1]]).flatten()
    np.testing.assert_array_equal(attention_metadata.request_distribution,
                                  expected_distribution)

    # 6. logits_indices: last-token index per request, -1 elsewhere.
    assert len(logits_indices) == 8  # padded_num_reqs
    expected_logits = np.full(8, -1, dtype=np.int32)
    expected_logits[0] = 1  # req1 last token (2 - 1)
    expected_logits[4] = 2  # req2 last token (3 - 1) at rank-1 offset
    assert np.array_equal(logits_indices, expected_logits)

    # 7. logits_indices_selector maps back to original request order.
    assert len(logits_indices_selector) == 2
    assert np.array_equal(logits_indices_selector, np.array([0, 4]))
@patch('tpu_inference.runner.tpu_runner.NamedSharding')
@patch('tpu_inference.runner.tpu_runner.runner_utils')
@patch('tpu_inference.runner.tpu_runner.device_array',
       side_effect=lambda mesh, tensors, **kwargs: tensors)
@patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
def test_prepare_inputs_dp_verify_content_empty_rank(
        self, mock_sampling_metadata, mock_device_array, mock_runner_utils,
        mock_named_sharding):
    """Content-level verification when every request lands on rank 0 and
    rank 1 is left empty."""

    def fake_padded_token_len(paddings_list, val):
        if val <= 2:
            return 4   # request padding (max 2 requests)
        if val <= 5:
            return 8   # token padding
        return 16

    mock_runner_utils.get_padded_token_len.side_effect = fake_padded_token_len
    mock_sampling_metadata.from_input_batch.return_value = MagicMock()
    mock_named_sharding.return_value = MagicMock()

    # Both requests assigned to rank 0; rank 1 receives nothing.
    batch = self.runner.input_batch
    batch.num_reqs = 2
    batch.req_ids = ["req1", "req2"]
    batch.num_computed_tokens_cpu = np.array([4, 6])  # starting positions

    # Deterministic token ids: row i holds (i+5)*1000 + 1 + column.
    batch.token_ids_cpu = np.zeros((8, 64), dtype=np.int32)
    for row in range(2):
        batch.token_ids_cpu[row, :] = (row + 5) * 1000 + 1 + np.arange(64)

    scheduler_output = self._create_mock_scheduler_output(
        {"req1": 3, "req2": 2}, {"req1": 0, "req2": 0})

    # Stub out collaborators the runner touches along the way.
    self.runner.uses_mrope = False
    self.runner.phase_based_profiler = None
    self.runner.lora_config = None
    for attr in ("mesh", "data_parallel_sharding",
                 "data_parallel_attn_sharding", "mm_manager",
                 "speculative_decoding_manager", "lora_utils"):
        setattr(self.runner, attr, MagicMock())

    (input_ids, positions, attention_metadata, sampling_metadata,
     logits_indices, spec_decode_metadata, logits_indices_selector,
     padded_num_reqs) = self.runner._prepare_inputs_dp(scheduler_output)

    # 1. input_ids: both requests packed into rank 0's slice; rank 1 zero.
    expected_input_ids = np.zeros(16, dtype=np.int32)
    expected_input_ids[:5] = [5005, 5006, 5007, 6007, 6008]
    assert np.array_equal(input_ids, expected_input_ids)

    # 2. positions: req1 at 4..6, req2 at 6..7; rank 1 stays zero.
    expected_positions = np.zeros(16, dtype=np.int32)
    expected_positions[:3] = [4, 5, 6]
    expected_positions[3:5] = [6, 7]
    assert np.array_equal(attention_metadata.input_positions,
                          expected_positions)

    # 3. query_start_loc: cumulative [3, 5] then padding; empty rank -> 0.
    max_num_reqs_per_dp = self.runner.max_num_reqs // 2
    expected_query_start = np.zeros(self.runner.max_num_reqs + 2,
                                    dtype=np.int32)
    expected_query_start[1] = 3  # req1: 3 tokens
    expected_query_start[2] = 5  # cumulative 3 + 2
    expected_query_start[3:max_num_reqs_per_dp + 1] = 1  # padding
    expected_query_start[max_num_reqs_per_dp + 1:] = 0   # empty rank
    assert np.array_equal(attention_metadata.query_start_loc_cpu,
                          expected_query_start)

    # 4. seq_lens: 4+3=7 and 6+2=8 on rank 0; rank 1 all zero.
    expected_seq_lens = np.zeros(8, dtype=np.int32)
    expected_seq_lens[0] = 7
    expected_seq_lens[1] = 8
    assert np.array_equal(attention_metadata.seq_lens_cpu,
                          expected_seq_lens)

    # 5. Two prefill-style requests on rank 0, none on rank 1.
    expected_distribution = np.array([[0, 0, 2], [0, 0, 0]]).flatten()
    np.testing.assert_array_equal(attention_metadata.request_distribution,
                                  expected_distribution)

    # 6. logits_indices: last-token positions 2 and 4; -1 elsewhere.
    assert len(logits_indices) == 8  # padded_num_reqs (8 here, not 16)
    expected_logits = np.full(8, -1, dtype=np.int32)
    expected_logits[0] = 2  # req1 ends at position 2 (3 - 1)
    expected_logits[1] = 4  # req2 ends at position 4 (5 - 1)
    assert np.array_equal(logits_indices, expected_logits)

    # 7. Selector preserves the original request order.
    assert len(logits_indices_selector) == 2
    np.testing.assert_array_equal(logits_indices_selector, np.array([0, 1]))
@patch('tpu_inference.runner.tpu_runner.NamedSharding')
@patch('tpu_inference.runner.tpu_runner.runner_utils')
@patch('tpu_inference.runner.tpu_runner.device_array',
       side_effect=lambda mesh, tensors, **kwargs: tensors)
@patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
def test_prepare_inputs_dp_with_decode_requests(self,
                                                mock_sampling_metadata,
                                                mock_device_array,
                                                mock_runner_utils,
                                                mock_named_sharding):
    """request_distribution distinguishes decode (1-token) requests from
    prefill (>1 token) requests per DP rank."""

    def fake_padded_token_len(paddings_list, val):
        if val <= 2:
            return 4
        if val <= 4:
            return 8
        return 16

    mock_runner_utils.get_padded_token_len.side_effect = fake_padded_token_len
    mock_sampling_metadata.from_input_batch.return_value = MagicMock()
    mock_named_sharding.return_value = MagicMock()

    # req3 is the only prefill (3 tokens); the other three are decodes.
    batch = self.runner.input_batch
    batch.num_reqs = 4
    batch.req_ids = ["req1", "req2", "req3", "req4"]
    batch.num_computed_tokens_cpu = np.array([5, 6, 7, 8])
    batch.token_ids_cpu = np.zeros((8, 64), dtype=np.int32)

    scheduler_output = self._create_mock_scheduler_output(
        {"req1": 1, "req2": 1, "req3": 3, "req4": 1},
        {"req1": 0, "req2": 0, "req3": 1, "req4": 1})

    # Stub out collaborators the runner touches along the way.
    self.runner.uses_mrope = False
    self.runner.phase_based_profiler = None
    self.runner.lora_config = None
    for attr in ("mesh", "data_parallel_sharding",
                 "data_parallel_attn_sharding", "mm_manager",
                 "speculative_decoding_manager", "lora_utils"):
        setattr(self.runner, attr, MagicMock())

    result = self.runner._prepare_inputs_dp(scheduler_output)
    attention_metadata = result[2]

    # Rank 0: two decodes -> [2, 2, 2];
    # rank 1: one prefill + one decode -> [1, 1, 2].
    expected_distribution = np.array([[2, 2, 2], [1, 1, 2]]).flatten()
    np.testing.assert_array_equal(attention_metadata.request_distribution,
                                  expected_distribution)
@patch('tpu_inference.runner.tpu_runner.NamedSharding')
@patch('tpu_inference.runner.tpu_runner.runner_utils')
@patch('tpu_inference.runner.tpu_runner.device_array',
       side_effect=lambda mesh, tensors, **kwargs: tensors)
@patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
def test_prepare_inputs_dp_all_decode_requests(self,
                                               mock_sampling_metadata,
                                               mock_device_array,
                                               mock_runner_utils,
                                               mock_named_sharding):
    """request_distribution when every request is a 1-token decode."""

    def fake_padded_token_len(paddings_list, val):
        if val <= 2:
            return 4
        if val <= 4:
            return 8
        return 16

    mock_runner_utils.get_padded_token_len.side_effect = fake_padded_token_len
    mock_sampling_metadata.from_input_batch.return_value = MagicMock()
    mock_named_sharding.return_value = MagicMock()

    # One decode request per rank.
    batch = self.runner.input_batch
    batch.num_reqs = 2
    batch.req_ids = ["req1", "req2"]
    batch.num_computed_tokens_cpu = np.array([5, 6])
    batch.token_ids_cpu = np.zeros((8, 64), dtype=np.int32)

    scheduler_output = self._create_mock_scheduler_output(
        {"req1": 1, "req2": 1}, {"req1": 0, "req2": 1})

    # Stub out collaborators the runner touches along the way.
    self.runner.uses_mrope = False
    self.runner.phase_based_profiler = None
    self.runner.lora_config = None
    for attr in ("mesh", "data_parallel_sharding",
                 "data_parallel_attn_sharding", "mm_manager",
                 "speculative_decoding_manager", "lora_utils"):
        setattr(self.runner, attr, MagicMock())

    result = self.runner._prepare_inputs_dp(scheduler_output)
    attention_metadata = result[2]

    # Each rank holds a single decode -> [1, 1, 1] on both ranks.
    expected_distribution = np.array([[1, 1, 1], [1, 1, 1]]).flatten()
    np.testing.assert_array_equal(attention_metadata.request_distribution,
                                  expected_distribution)
@patch('tpu_inference.runner.tpu_runner.NamedSharding')
@patch('tpu_inference.runner.tpu_runner.runner_utils')
@patch('tpu_inference.runner.tpu_runner.device_array',
       side_effect=lambda mesh, tensors, **kwargs: tensors)
@patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
def test_prepare_async_token_substitution_indices_dp(
        self, mock_sampling_metadata, mock_device_array, mock_runner_utils,
        mock_named_sharding):
    """Placeholder requests map to per-rank input/next-token index pairs."""
    req_ids_dp = {0: ["req1", "req2"], 1: ["req3"]}
    scheduled_tokens_per_dp_rank = {0: [3, 2], 1: [4]}
    padded_tokens_per_rank = 8
    dp_size = 2

    # req1 and req3 are placeholders; req2 is not and must be skipped.
    self.runner._pre_async_results = MagicMock()
    self.runner._pre_async_results.placeholder_req_id_to_index = {
        "req1": 0,
        "req3": 2,
    }

    cur_indices_dp, next_indices_dp = (
        self.runner._prepare_async_token_substitution_indices_dp(
            req_ids_dp, scheduled_tokens_per_dp_rank,
            padded_tokens_per_rank, dp_size))

    # Rank 0: req1's last scheduled token sits at offset 3 - 1 = 2;
    # req2 contributes nothing because it has no placeholder entry.
    assert cur_indices_dp[0] == [2]
    assert next_indices_dp[0] == [0]

    # Rank 1: token offset 8 plus req3's 4 tokens -> last token at 11.
    assert cur_indices_dp[1] == [11]
    assert next_indices_dp[1] == [2]
@patch('tpu_inference.runner.tpu_runner.NamedSharding')
@patch('tpu_inference.runner.tpu_runner.runner_utils')
@patch('tpu_inference.runner.tpu_runner.device_array',
       side_effect=lambda mesh, tensors, **kwargs: tensors)
@patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
def test_prepare_async_token_substitution_indices_dp_no_placeholders(
        self, mock_sampling_metadata, mock_device_array, mock_runner_utils,
        mock_named_sharding):
    """With no placeholder requests every per-rank index list is empty."""
    req_ids_dp = {0: ["req1", "req2"], 1: ["req3"]}
    scheduled_tokens_per_dp_rank = {0: [3, 2], 1: [4]}
    padded_tokens_per_rank = 8
    dp_size = 2

    # Empty placeholder mapping: nothing to substitute.
    self.runner._pre_async_results = MagicMock()
    self.runner._pre_async_results.placeholder_req_id_to_index = {}

    cur_indices_dp, next_indices_dp = (
        self.runner._prepare_async_token_substitution_indices_dp(
            req_ids_dp, scheduled_tokens_per_dp_rank,
            padded_tokens_per_rank, dp_size))

    for rank in (0, 1):
        assert cur_indices_dp[rank] == []
        assert next_indices_dp[rank] == []
def test_apply_async_token_substitution_empty_indices(self):
    """With no substitution indices the input ids pass through untouched."""
    # Bind the real implementation onto the (otherwise mocked) runner.
    self.runner._apply_async_token_substitution = (
        TPUModelRunner._apply_async_token_substitution.__get__(self.runner))

    input_ids = np.array([1, 2, 3, 4, 5])

    self.runner._pre_async_results = MagicMock()
    self.runner._pre_async_results.next_tokens = np.array([10, 20, 30])
    self.runner.mesh = MagicMock()

    # Empty index arrays mean there is nothing to substitute.
    result = self.runner._apply_async_token_substitution(
        input_ids, np.array([]), np.array([]))

    np.testing.assert_array_equal(result, input_ids)
@patch('tpu_inference.runner.tpu_runner.device_array',
       side_effect=lambda mesh, tensors, **kwargs: tensors)
def test_apply_async_token_substitution_with_padding(
        self, mock_device_array):
    """Substitution index arrays are padded to the input length before the
    substitute kernel is invoked."""
    # Bind the real implementation onto the (otherwise mocked) runner.
    self.runner._apply_async_token_substitution = (
        TPUModelRunner._apply_async_token_substitution.__get__(self.runner))

    input_ids = np.array([1, 2, 3, 4, 5, 6, 7, 8])
    cur_indices = np.array([2, 5])    # positions in input_ids to overwrite
    next_indices = np.array([0, 1])   # which next_tokens entries to use

    self.runner._pre_async_results = MagicMock()
    self.runner._pre_async_results.next_tokens = np.array([100, 200, 300])
    self.runner.mesh = MagicMock()
    self.runner.maybe_forbid_compile = nullcontext()

    # Stub the substitution kernel so we can inspect its arguments.
    substitute_fn = MagicMock(
        return_value=np.array([1, 2, 100, 4, 5, 200, 7, 8]))
    self.runner._substitute_placeholder_token_fn = substitute_fn

    self.runner._apply_async_token_substitution(input_ids, cur_indices,
                                                next_indices)

    substitute_fn.assert_called_once()
    args = substitute_fn.call_args[0]

    # The original input ids are forwarded unchanged.
    np.testing.assert_array_equal(args[0], input_ids)

    # Both index arrays are padded out to len(input_ids).
    assert len(args[1]) == len(input_ids)
    assert len(args[2]) == len(input_ids)

    # Only two real substitutions were requested.
    assert args[4] == 2
def test_prepare_inputs_routing_to_dp(self):
    """_prepare_inputs dispatches to the DP path when dp_size > 1."""
    # Bind the real router; mock the DP branch it should select.
    self.runner._prepare_inputs = TPUModelRunner._prepare_inputs.__get__(
        self.runner)
    self.runner.dp_size = 2
    self.runner._prepare_inputs_dp = MagicMock(
        return_value=(None, None, None, None, None, None))

    scheduler_output = MagicMock()
    self.runner._prepare_inputs(scheduler_output)

    self.runner._prepare_inputs_dp.assert_called_once_with(scheduler_output)
def test_prepare_inputs_routing_to_non_dp(self):
    """_prepare_inputs dispatches to the non-DP path when dp_size == 1."""
    # Bind the real router; mock the non-DP branch it should select.
    self.runner._prepare_inputs = TPUModelRunner._prepare_inputs.__get__(
        self.runner)
    self.runner.dp_size = 1
    self.runner._prepare_inputs_non_dp = MagicMock(
        return_value=(None, None, None, None, None, None, None))

    scheduler_output = MagicMock()
    self.runner._prepare_inputs(scheduler_output)

    self.runner._prepare_inputs_non_dp.assert_called_once_with(
        scheduler_output)
@patch('tpu_inference.runner.tpu_runner.NamedSharding')
@patch('tpu_inference.runner.tpu_runner.runner_utils')
@patch('tpu_inference.runner.tpu_runner.device_array',
       side_effect=lambda mesh, tensors, **kwargs: tensors)
@patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
def test_prepare_inputs_dp_with_async_scheduling(self,
                                                 mock_sampling_metadata,
                                                 mock_device_array,
                                                 mock_runner_utils,
                                                 mock_named_sharding):
    """Async scheduling triggers placeholder-index preparation in DP mode."""

    def fake_padded_token_len(paddings_list, val):
        if val <= 2:
            return 4
        if val <= 5:
            return 8
        return 16

    mock_runner_utils.get_padded_token_len.side_effect = fake_padded_token_len
    mock_sampling_metadata.from_input_batch.return_value = MagicMock()
    mock_named_sharding.return_value = MagicMock()

    batch = self.runner.input_batch
    batch.num_reqs = 2
    batch.req_ids = ["req1", "req2"]
    batch.num_computed_tokens_cpu = np.array([4, 6])
    batch.token_ids_cpu = np.zeros((8, 64), dtype=np.int32)

    scheduler_output = self._create_mock_scheduler_output(
        {"req1": 3, "req2": 2}, {"req1": 0, "req2": 1})

    # Async scheduling enabled, with req1 registered as a placeholder.
    self.runner.scheduler_config.async_scheduling = True
    self.runner._pre_async_results = MagicMock()
    self.runner._pre_async_results.placeholder_req_id_to_index = {"req1": 0}
    self.runner._pre_async_results.next_tokens = np.array([100])

    # Stub out collaborators the runner touches along the way.
    self.runner.uses_mrope = False
    self.runner.phase_based_profiler = None
    self.runner.lora_config = None
    for attr in ("mesh", "data_parallel_sharding",
                 "data_parallel_attn_sharding", "mm_manager",
                 "speculative_decoding_manager", "lora_utils"):
        setattr(self.runner, attr, MagicMock())

    # Stub the index-preparation helper so the call can be observed.
    mock_prepare_async = MagicMock(
        return_value=({0: [2], 1: []}, {0: [0], 1: []}))
    self.runner._prepare_async_token_substitution_indices_dp = (
        mock_prepare_async)

    self.runner._prepare_inputs_dp(scheduler_output)

    # The DP path must consult the async substitution-index helper.
    mock_prepare_async.assert_called_once()
@patch('tpu_inference.runner.tpu_runner.NamedSharding')
@patch('tpu_inference.runner.tpu_runner.runner_utils')
@patch('tpu_inference.runner.tpu_runner.device_array',
       side_effect=lambda mesh, tensors, **kwargs: tensors)
@patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
def test_prepare_inputs_dp_async_token_substitution_application(
        self, mock_sampling_metadata, mock_device_array, mock_runner_utils,
        mock_named_sharding):
    """DP prepare-inputs concatenates per-rank substitution indices and
    hands them to _apply_async_token_substitution."""

    def fake_padded_token_len(paddings_list, val):
        if val <= 2:
            return 4
        if val <= 5:
            return 8
        return 16

    mock_runner_utils.get_padded_token_len.side_effect = fake_padded_token_len
    mock_sampling_metadata.from_input_batch.return_value = MagicMock()
    mock_named_sharding.return_value = MagicMock()

    batch = self.runner.input_batch
    batch.num_reqs = 2
    batch.req_ids = ["req1", "req2"]
    batch.num_computed_tokens_cpu = np.array([4, 6])
    batch.token_ids_cpu = np.zeros((8, 64), dtype=np.int32)

    scheduler_output = self._create_mock_scheduler_output(
        {"req1": 3, "req2": 2}, {"req1": 0, "req2": 1})

    # Async scheduling with both requests registered as placeholders.
    self.runner.scheduler_config.async_scheduling = True
    self.runner._pre_async_results = MagicMock()
    self.runner._pre_async_results.placeholder_req_id_to_index = {
        "req1": 0,
        "req2": 1,
    }
    self.runner._pre_async_results.next_tokens = np.array([100, 200])

    # Stub out collaborators the runner touches along the way.
    self.runner.uses_mrope = False
    self.runner.phase_based_profiler = None
    self.runner.lora_config = None
    for attr in ("mesh", "data_parallel_sharding",
                 "data_parallel_attn_sharding", "mm_manager",
                 "speculative_decoding_manager", "lora_utils"):
        setattr(self.runner, attr, MagicMock())

    # Stub the application step so its arguments can be inspected.
    mock_apply_async = MagicMock(
        return_value=np.array([1, 2, 100, 4, 5, 200, 7, 8]))
    self.runner._apply_async_token_substitution = mock_apply_async

    self.runner._prepare_inputs_dp(scheduler_output)

    mock_apply_async.assert_called_once()
    args = mock_apply_async.call_args[0]

    # One substituted position per rank, concatenated across both ranks.
    assert len(args[1]) == 2
    assert len(args[2]) == 2
+
1098
if __name__ == "__main__":
    # Allow running this test module directly via `python <file>`.
    pytest.main([__file__])