tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.2rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic. Click here for more details.

Files changed (250)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +89 -26
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +46 -17
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +44 -17
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/METADATA +7 -9
  240. tpu_inference-0.13.2rc3.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.2rc3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,429 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from unittest.mock import MagicMock, patch
16
+
17
+ import jax
18
+ import jax.numpy as jnp
19
+ import numpy as np
20
+ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
21
+ SchedulerConfig, SpeculativeConfig, VllmConfig)
22
+ from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
23
+ from vllm.multimodal.inputs import (MultiModalBatchedField,
24
+ MultiModalFeatureSpec, MultiModalFieldElem,
25
+ MultiModalKwargsItem, PlaceholderRange)
26
+ from vllm.sampling_params import SamplingType
27
+ from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
28
+
29
+ from tpu_inference.runner.input_batch import CachedRequestState
30
+ from tpu_inference.runner.tpu_runner import TPUModelRunner
31
+
32
+
33
class TestMultiModalManager:
    """Tests for the multimodal manager exposed as ``TPUModelRunner.mm_manager``.

    Covered entry points: ``execute_mm_encoder``, ``gather_mm_embeddings``
    and ``calc_mrope_positions``.
    """

    def setup_method(self):
        """Create ``self.runner`` with JAX entry points and model loading patched."""
        # Mock JAX dependencies
        self.mock_devices = [MagicMock(coords=i) for i in range(1)]
        # The mesh is built from one real device *before* ``jax.make_mesh`` is
        # patched, so the patched functions below hand back a usable mesh.
        device_array = np.array(jax.devices()[:1]).reshape(1, 1, 1, 1)
        self.mock_mesh = jax.make_mesh(device_array.shape,
                                       ('data', 'attn_dp', 'expert', 'model'))
        self.mock_rng_key = MagicMock()

        with patch('jax.devices', return_value=self.mock_devices), \
             patch('jax.make_mesh', return_value=self.mock_mesh), \
             patch('jax.random.key', return_value=self.mock_rng_key), \
             patch('tpu_inference.runner.tpu_runner.get_model',
                   return_value=MagicMock()), \
             patch('tpu_inference.runner.tpu_runner.make_optimized_mesh',
                   return_value=self.mock_mesh):

            model_config = ModelConfig(tokenizer_mode="auto",
                                       trust_remote_code=False,
                                       seed=0,
                                       dtype='bfloat16')
            cache_config = CacheConfig(
                block_size=16,
                gpu_memory_utilization=0.9,
                swap_space=4,
                cache_dtype="auto",
            )
            scheduler_config = SchedulerConfig(max_num_seqs=16,
                                               max_model_len=1024,
                                               is_encoder_decoder=False)
            parallel_config = ParallelConfig(
                pipeline_parallel_size=1,
                tensor_parallel_size=1,
                worker_use_ray=False,
            )
            speculative_config = SpeculativeConfig(
                model='ngram',
                num_speculative_tokens=5,
                prompt_lookup_max=4,
            )
            vllm_config = VllmConfig(
                model_config=model_config,
                cache_config=cache_config,
                scheduler_config=scheduler_config,
                parallel_config=parallel_config,
                speculative_config=speculative_config,
                observability_config={},
                additional_config={},
            )

            self.runner = TPUModelRunner(vllm_config,
                                         devices=self.mock_devices)

    # ------------------------------------------------------------------
    # Private helpers (underscore-prefixed, so pytest does not collect them)
    # ------------------------------------------------------------------

    @staticmethod
    def _make_greedy_sampling_params():
        """Return a ``MagicMock`` configured like greedy sampling parameters."""
        params = MagicMock()
        params.sampling_type = SamplingType.GREEDY
        params.top_k = -1
        params.top_p = 1.0
        params.temperature = 0.0
        params.min_tokens = 0
        params.logprobs = None
        params.logit_bias = None
        params.allowed_token_ids = set()
        params.bad_words_token_ids = None
        params.all_stop_token_ids = set()
        return params

    @staticmethod
    def _make_image_item(pixel_values, grid_thw):
        """Wrap pixel values and grid dims in a ``MultiModalKwargsItem``."""
        return MultiModalKwargsItem.from_elems([
            MultiModalFieldElem("image", "pixel_values", pixel_values,
                                MultiModalBatchedField()),
            MultiModalFieldElem("image", "image_grid_thw", grid_thw,
                                MultiModalBatchedField())
        ])

    @staticmethod
    def _make_image_feature(req_id, data, offset, length):
        """Return an image ``MultiModalFeatureSpec`` identified by ``req_id``."""
        return MultiModalFeatureSpec(data=data,
                                     identifier=req_id,
                                     modality="image",
                                     mm_position=PlaceholderRange(
                                         offset=offset, length=length))

    @staticmethod
    def _make_request_state(req_id,
                            prompt_token_ids,
                            mm_features,
                            *,
                            sampling_params=None,
                            block_ids=(),
                            num_computed_tokens=0,
                            **extra_fields):
        """Build a ``CachedRequestState`` with the defaults shared by all tests.

        ``sampling_params`` falls back to a plain ``MagicMock``. Any
        ``extra_fields`` (e.g. ``mrope_positions``) are forwarded unchanged
        to the ``CachedRequestState`` constructor.
        """
        if sampling_params is None:
            sampling_params = MagicMock()
        return CachedRequestState(
            req_id=req_id,
            prompt_token_ids=prompt_token_ids,
            output_token_ids=[],
            sampling_params=sampling_params,
            block_ids=block_ids,
            num_computed_tokens=num_computed_tokens,
            mm_features=mm_features,
            lora_request=None,
            pooling_params=None,
            generator=None,
            **extra_fields,
        )

    def _install_mock_encoder(self, *embeddings):
        """Replace the runner's multimodal encoder with a mock.

        The mock is stored on ``self.mock_get_mm_embed_fn`` and returns the
        given ``embeddings`` as a tuple.
        """
        self.runner.is_multimodal_model = True
        self.mock_get_mm_embed_fn = MagicMock()
        self.runner.embed_multimodal_fn = self.mock_get_mm_embed_fn
        self.runner.state = MagicMock()
        self.mock_get_mm_embed_fn.return_value = tuple(embeddings)

    def _assert_single_encoder_call(self, expected_grid):
        """Check the mocked encoder was called once and return its kwargs.

        Positional args are expected to be ``(state, image_grid_thw)``;
        the batched multimodal inputs arrive as keyword args.
        """
        self.mock_get_mm_embed_fn.assert_called_once()
        call_args = self.mock_get_mm_embed_fn.call_args

        state_arg, grid_arg = call_args.args
        kwargs_arg = call_args.kwargs

        assert state_arg == self.runner.state
        assert grid_arg == expected_grid
        assert "pixel_values" in kwargs_arg
        return kwargs_arg

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_execute_mm_encoder_single_image(self):
        """Tests mm_manager.execute_mm_encoder with a single request and a single image."""
        import torch

        # 1. ===== Setup =====
        # Mock the return value of the multimodal encoder
        dummy_embedding = jnp.ones((10, 128), dtype=jnp.bfloat16)
        self._install_mock_encoder(dummy_embedding)

        # Mock scheduler output
        mock_scheduler_output = MagicMock(spec=VllmSchedulerOutput)
        mock_scheduler_output.scheduled_encoder_inputs = {"req-1": [0]}

        # Mock request state
        dummy_pixel_values = torch.randn(3, 224, 224, dtype=torch.bfloat16)
        dummy_grid_thw = torch.tensor([[1, 1, 1]], dtype=torch.int64)
        mm_item = self._make_image_item(dummy_pixel_values, dummy_grid_thw)
        req_state = self._make_request_state(
            "req-1",
            [1, 2, 3],
            [self._make_image_feature("req-1", mm_item, offset=0, length=1)],
        )
        self.runner.requests = {"req-1": req_state}

        # 2. ===== Act =====
        self.runner.mm_manager.execute_mm_encoder(mock_scheduler_output)

        # 3. ===== Assert =====
        # Check if encoder_cache is populated correctly
        assert "req-1" in self.runner.encoder_cache
        cached_embedding = self.runner.encoder_cache["req-1"]
        np.testing.assert_array_equal(np.asarray(cached_embedding),
                                      np.asarray(dummy_embedding))

        # Check if embed_multimodal_fn was called with correct args
        kwargs_arg = self._assert_single_encoder_call(((1, 1, 1), ))

        # Verify the pixel values tensor passed to the mock
        passed_pixel_values = kwargs_arg['pixel_values']
        assert isinstance(passed_pixel_values, np.ndarray)
        assert passed_pixel_values.dtype == jnp.bfloat16

        # Convert torch tensor for comparison (via float32, as in the
        # original test, before casting to the bfloat16 numpy dtype).
        expected_pixel_values = dummy_pixel_values.unsqueeze(0).to(
            torch.float32).numpy().astype(jnp.bfloat16)
        np.testing.assert_array_equal(np.asarray(passed_pixel_values),
                                      expected_pixel_values)

    def test_execute_mm_encoder_multiple_images(self):
        """Tests mm_manager.execute_mm_encoder with multiple requests and images."""
        import torch

        # 1. ===== Setup =====
        emb_1 = jnp.ones((10, 128), dtype=jnp.bfloat16)
        emb_2 = jnp.ones((20, 128), dtype=jnp.bfloat16) * 2
        self._install_mock_encoder(emb_1, emb_2)

        # Mock scheduler output for two requests
        mock_scheduler_output = MagicMock(spec=VllmSchedulerOutput)
        mock_scheduler_output.scheduled_encoder_inputs = {
            "req-1": [0],
            "req-2": [0]
        }

        # Mock request states
        px_1 = torch.randn(3, 224, 224, dtype=torch.bfloat16)
        grid_1 = torch.tensor([[1, 1, 1]], dtype=torch.int64)
        req_state_1 = self._make_request_state(
            "req-1",
            [],
            [
                self._make_image_feature("req-1",
                                         self._make_image_item(px_1, grid_1),
                                         offset=0,
                                         length=1)
            ],
        )

        px_2 = torch.randn(3, 224, 224, dtype=torch.bfloat16)
        grid_2 = torch.tensor([[1, 2, 2]], dtype=torch.int64)
        req_state_2 = self._make_request_state(
            "req-2",
            [],
            [
                self._make_image_feature("req-2",
                                         self._make_image_item(px_2, grid_2),
                                         offset=0,
                                         length=1)
            ],
        )

        self.runner.requests = {"req-1": req_state_1, "req-2": req_state_2}

        # 2. ===== Act =====
        self.runner.mm_manager.execute_mm_encoder(mock_scheduler_output)

        # 3. ===== Assert =====
        assert "req-1" in self.runner.encoder_cache
        np.testing.assert_array_equal(
            np.asarray(self.runner.encoder_cache["req-1"]), np.asarray(emb_1))
        assert "req-2" in self.runner.encoder_cache
        np.testing.assert_array_equal(
            np.asarray(self.runner.encoder_cache["req-2"]), np.asarray(emb_2))

        kwargs_arg = self._assert_single_encoder_call(((1, 1, 1), (1, 2, 2)))

        passed_pixel_values = kwargs_arg['pixel_values']
        assert passed_pixel_values.shape == (2, 3, 224, 224)

        expected_pixel_values = torch.stack([px_1, px_2], dim=0).to(
            torch.float32).numpy().astype(jnp.bfloat16)
        np.testing.assert_array_equal(np.asarray(passed_pixel_values),
                                      expected_pixel_values)

    def test_gather_mm_embeddings_chunked_prefill(self):
        """Tests mm_manager.gather_mm_embeddings with chunked prefill scenarios."""
        # 1. ===== Setup =====
        self.runner.is_multimodal_model = True
        req_id = "req-1"

        # Mock encoder output
        encoder_embedding = jnp.arange(56 * 128, dtype=jnp.bfloat16).reshape(
            (56, 128))
        self.runner.encoder_cache = {req_id: encoder_embedding}

        # Mock request state: a 100-token prompt whose image placeholder
        # occupies prompt positions [10, 66).
        req_state = self._make_request_state(
            req_id,
            list(range(100)),
            [self._make_image_feature(req_id, None, offset=10, length=56)],
            sampling_params=self._make_greedy_sampling_params(),
            block_ids=([], ),
            num_computed_tokens=0,  # This will be updated per step
        )
        self.runner.requests = {req_id: req_state}
        self.runner.input_batch.add_request(req_state)

        # 2. ===== Act & Assert =====
        # Each row: (num_computed_tokens, num_scheduled_tokens,
        #            target_pad_len, expected rows of the encoder output).
        # The expected rows are the part of the placeholder [10, 66) that
        # falls inside the scheduled token window, re-based to row 0.
        prefill_chunks = [
            (0, 20, 10, slice(0, 10)),  # first chunk: tokens [0, 20)
            (20, 30, 30, slice(10, 40)),  # middle chunk: tokens [20, 50)
            (50, 30, 16, slice(40, 56)),  # last chunk: tokens [50, 80)
        ]
        for num_computed, num_scheduled, pad_len, rows in prefill_chunks:
            req_state.num_computed_tokens = num_computed
            mock_scheduler_output = MagicMock(spec=VllmSchedulerOutput)
            mock_scheduler_output.num_scheduled_tokens = {
                req_id: num_scheduled
            }

            gathered_embeds = self.runner.mm_manager.gather_mm_embeddings(
                mock_scheduler_output, target_pad_len=pad_len)

            expected_embeds = encoder_embedding[rows]
            assert gathered_embeds.shape == expected_embeds.shape
            np.testing.assert_array_equal(np.asarray(gathered_embeds),
                                          np.asarray(expected_embeds))

    def test_calc_mrope_positions(self):
        """Tests the calculation of M-RoPE positions for mixed prompt/completion."""
        # 1. ===== Setup =====
        self.runner.uses_mrope = True
        req_id = "req-1"
        prompt_len = 20
        num_computed = 15
        num_scheduled = 10
        mrope_delta = 100
        # The scheduled window straddles the end of the prompt: the first
        # part is still prompt, the remainder is completion.
        num_prompt_part = prompt_len - num_computed
        num_completion_part = num_scheduled - num_prompt_part

        # Mock request state with pre-computed mrope positions for the prompt
        mock_mrope_positions = np.arange(3 * prompt_len,
                                         dtype=np.int64).reshape(
                                             3, prompt_len)

        req_state = self._make_request_state(
            req_id,
            list(range(prompt_len)),
            [],
            sampling_params=self._make_greedy_sampling_params(),
            block_ids=([], ),
            num_computed_tokens=num_computed,
            mrope_positions=mock_mrope_positions,
            mrope_position_delta=mrope_delta,
        )
        self.runner.requests = {req_id: req_state}
        self.runner.input_batch.add_request(req_state)
        # Manually set num_computed_tokens in the batch as add_request sets it to 0
        self.runner.input_batch.num_computed_tokens_cpu[0] = num_computed

        # Mock scheduler output
        mock_scheduler_output = MagicMock(spec=VllmSchedulerOutput)
        mock_scheduler_output.num_scheduled_tokens = {req_id: num_scheduled}

        # Patch the static method that computes completion positions
        with patch.object(MRotaryEmbedding,
                          "get_next_input_positions_tensor") as mock_get_next:
            # 2. ===== Act =====
            self.runner.mm_manager.calc_mrope_positions(mock_scheduler_output)

            # 3. ===== Assert =====
            # The prompt part should be copied from the pre-computed prompt positions
            expected_prompt_part = mock_mrope_positions[:,
                                                        num_computed:prompt_len]
            actual_prompt_part = self.runner.mrope_positions_cpu[:, 0:
                                                                 num_prompt_part]
            np.testing.assert_array_equal(actual_prompt_part,
                                          expected_prompt_part)

            # The completion part should be computed on-the-fly
            mock_get_next.assert_called_once()
            call_kwargs = mock_get_next.call_args.kwargs
            np.testing.assert_array_equal(call_kwargs["out"],
                                          self.runner.mrope_positions_cpu)
            assert call_kwargs["out_offset"] == num_prompt_part
            assert call_kwargs["mrope_position_delta"] == mrope_delta
            assert call_kwargs["context_len"] == prompt_len
            assert call_kwargs["num_new_tokens"] == num_completion_part
@@ -0,0 +1,84 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import unittest
16
+ from unittest.mock import MagicMock
17
+
18
+ import numpy as np
19
+
20
+ from tpu_inference.runner.persistent_batch_manager import \
21
+ PersistentBatchManager
22
+
23
+
24
class TestPersistentBatchManager(unittest.TestCase):
    """Unit tests for ``PersistentBatchManager`` using mocked collaborators."""

    def test_update_states_pp_non_last_rank(self):
        """Check ``update_states`` on a manager built with ``is_last_rank=False``.

        The scheduler reports one new token for an already-cached request.
        After ``update_states`` the request's ``output_token_ids`` and the
        input batch (``token_ids_cpu``, ``num_tokens``,
        ``num_tokens_no_spec``) must all reflect that token.
        """
        request_id = 101
        prior_outputs = [10, 20]
        sampled_tokens = [30]

        # Request that already holds two output tokens.
        cached_request = MagicMock(num_tokens=2,
                                   output_token_ids=list(prior_outputs))

        # Single-slot batch with room for ten token ids; two are in use.
        batch = MagicMock(
            req_id_to_index={request_id: 0},
            num_prompt_tokens=np.array([2], dtype=np.int32),
            token_ids_cpu=np.zeros((1, 10), dtype=np.int32),
            num_tokens=np.array([2], dtype=np.int32),
            num_tokens_no_spec=np.array([2], dtype=np.int32),
            num_reqs=1,
        )

        manager = PersistentBatchManager({request_id: cached_request},
                                         batch,
                                         MagicMock(),
                                         False,
                                         MagicMock(),
                                         is_last_rank=False)

        # Scheduler output carrying exactly one cached request and no
        # speculative-decode tokens.
        cached_reqs = MagicMock(
            req_ids=[request_id],
            num_computed_tokens=[2],
            new_token_ids=[sampled_tokens],
            new_block_ids=[None],
            resumed_from_preemption=[False],
            num_output_tokens=[len(prior_outputs) + 1],
        )
        sched_out = MagicMock(scheduled_cached_reqs=cached_reqs,
                              scheduled_spec_decode_tokens={})

        manager.update_states(sched_out, None)

        # The new token is appended to the request's outputs ...
        self.assertEqual(cached_request.output_token_ids,
                         prior_outputs + sampled_tokens)

        # ... written into the batch right after the two existing tokens ...
        np.testing.assert_array_equal(
            manager.input_batch.token_ids_cpu[0, 2:3],
            np.array(sampled_tokens, dtype=np.int32))

        # ... and both token counters advance from 2 to 3.
        self.assertEqual(manager.input_batch.num_tokens[0], 3)
        self.assertEqual(manager.input_batch.num_tokens_no_spec[0], 3)