tpu-inference 0.11.1.dev202512030818__py3-none-any.whl → 0.13.0rc2.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of tpu-inference might be problematic.
Files changed (250)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_disagg_utils.py +14 -0
  4. tests/core/test_dp_scheduler.py +650 -768
  5. tests/core/test_init.py +14 -0
  6. tests/distributed/__init__.py +13 -0
  7. tests/distributed/test_distributed_utils.py +120 -0
  8. tests/distributed/test_tpu_connector.py +478 -0
  9. tests/e2e/__init__.py +13 -0
  10. tests/e2e/test_async_scheduler.py +211 -0
  11. tests/e2e/test_data_parallel.py +289 -0
  12. tests/e2e/test_hybrid_kvcache.py +219 -0
  13. tests/e2e/test_local_disagg.py +257 -0
  14. tests/e2e/test_model_loader.py +268 -0
  15. tests/e2e/test_multi_modal_inference.py +111 -0
  16. tests/e2e/test_pipeline_parallel.py +265 -0
  17. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  18. tests/e2e/test_sampling_params.py +269 -0
  19. tests/e2e/test_speculative_decoding.py +311 -0
  20. tests/e2e/test_structured_decoding.py +46 -0
  21. tests/executors/__init__.py +13 -0
  22. tests/executors/test_ray_distributed_executor.py +199 -0
  23. tests/experimental/__init__.py +13 -0
  24. tests/experimental/test_llama3_jax_stashed.py +208 -0
  25. tests/kernels/__init__.py +13 -0
  26. tests/kernels/collectives/__init__.py +13 -0
  27. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  28. tests/kernels/fused_moe_v1_test.py +14 -0
  29. tests/kernels/gmm_test.py +205 -0
  30. tests/kernels/mla_v1_test.py +143 -41
  31. tests/kernels/quantized_matmul_kernel_test.py +2 -34
  32. tests/kernels/ragged_kv_cache_update_v2_test.py +14 -0
  33. tests/kernels/ragged_paged_attention_kernel_v2_test.py +14 -0
  34. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +17 -1
  35. tests/kernels/ragged_paged_attention_kernel_v3_test.py +17 -1
  36. tests/layers/__init__.py +13 -0
  37. tests/layers/common/__init__.py +13 -0
  38. tests/layers/common/test_attention_interface.py +156 -0
  39. tests/layers/common/test_quantization.py +149 -0
  40. tests/layers/jax/__init__.py +13 -0
  41. tests/layers/jax/attention/__init__.py +13 -0
  42. tests/layers/jax/attention/test_common_attention.py +103 -0
  43. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  44. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  45. tests/layers/jax/moe/__init__.py +13 -0
  46. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  47. tests/layers/jax/sample/__init__.py +13 -0
  48. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  49. tests/layers/jax/sample/test_sampling.py +115 -0
  50. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  51. tests/layers/jax/test_layers.py +155 -0
  52. tests/{test_quantization.py → layers/jax/test_qwix.py} +183 -50
  53. tests/layers/jax/test_rope.py +93 -0
  54. tests/layers/jax/test_sharding.py +159 -0
  55. tests/layers/jax/test_transformer_block.py +152 -0
  56. tests/layers/vllm/__init__.py +13 -0
  57. tests/layers/vllm/test_attention.py +363 -0
  58. tests/layers/vllm/test_awq.py +405 -0
  59. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  60. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +418 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +441 -0
  62. tests/layers/vllm/test_fp8.py +17 -0
  63. tests/layers/vllm/test_mxfp4.py +312 -0
  64. tests/layers/vllm/test_unquantized.py +651 -0
  65. tests/layers/vllm/utils.py +87 -0
  66. tests/lora/__init__.py +13 -0
  67. tests/lora/conftest.py +14 -0
  68. tests/lora/test_bgmv.py +14 -0
  69. tests/lora/test_layers.py +21 -3
  70. tests/lora/test_lora.py +15 -1
  71. tests/lora/test_lora_perf.py +67 -0
  72. tests/models/__init__.py +13 -0
  73. tests/models/common/__init__.py +13 -0
  74. tests/models/common/test_model_loader.py +455 -0
  75. tests/models/jax/__init__.py +13 -0
  76. tests/models/jax/test_deepseek_v3.py +401 -0
  77. tests/models/jax/test_llama3.py +184 -0
  78. tests/models/jax/test_llama4.py +298 -0
  79. tests/models/jax/test_llama_eagle3.py +197 -0
  80. tests/models/jax/test_llama_guard_4.py +242 -0
  81. tests/models/jax/test_qwen2.py +172 -0
  82. tests/models/jax/test_qwen2_5_vl.py +605 -0
  83. tests/models/jax/test_qwen3.py +169 -0
  84. tests/models/jax/test_weight_loading.py +180 -0
  85. tests/models/jax/utils/__init__.py +13 -0
  86. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  87. tests/platforms/__init__.py +13 -0
  88. tests/platforms/test_tpu_platform.py +54 -0
  89. tests/runner/__init__.py +13 -0
  90. tests/runner/test_block_table.py +395 -0
  91. tests/runner/test_input_batch.py +226 -0
  92. tests/runner/test_kv_cache.py +220 -0
  93. tests/runner/test_kv_cache_manager.py +498 -0
  94. tests/runner/test_multimodal_manager.py +429 -0
  95. tests/runner/test_persistent_batch_manager.py +84 -0
  96. tests/runner/test_speculative_decoding_manager.py +368 -0
  97. tests/runner/test_structured_decoding_manager.py +220 -0
  98. tests/runner/test_tpu_runner.py +261 -0
  99. tests/runner/test_tpu_runner_dp.py +1099 -0
  100. tests/runner/test_tpu_runner_mesh.py +200 -0
  101. tests/runner/test_utils.py +411 -0
  102. tests/spec_decode/__init__.py +13 -0
  103. tests/spec_decode/test_eagle3.py +311 -0
  104. tests/test_base.py +14 -0
  105. tests/test_envs.py +78 -1
  106. tests/test_tpu_info.py +14 -0
  107. tests/test_utils.py +1 -43
  108. tests/worker/__init__.py +13 -0
  109. tests/worker/tpu_worker_test.py +414 -0
  110. tpu_inference/__init__.py +14 -0
  111. tpu_inference/core/__init__.py +13 -0
  112. tpu_inference/core/sched/__init__.py +13 -0
  113. tpu_inference/core/sched/dp_scheduler.py +372 -56
  114. tpu_inference/distributed/__init__.py +13 -0
  115. tpu_inference/distributed/jax_parallel_state.py +14 -0
  116. tpu_inference/distributed/tpu_connector.py +14 -9
  117. tpu_inference/distributed/utils.py +56 -4
  118. tpu_inference/envs.py +38 -7
  119. tpu_inference/executors/__init__.py +13 -0
  120. tpu_inference/executors/ray_distributed_executor.py +17 -0
  121. tpu_inference/experimental/__init__.py +13 -0
  122. tpu_inference/experimental/llama3_jax_stashed.py +14 -0
  123. tpu_inference/kernels/__init__.py +13 -0
  124. tpu_inference/kernels/collectives/__init__.py +13 -0
  125. tpu_inference/kernels/collectives/all_gather_matmul.py +12 -6
  126. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +7 -2
  127. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  128. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  129. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  130. tpu_inference/kernels/fused_moe/v1/kernel.py +370 -324
  131. tpu_inference/kernels/megablox/__init__.py +13 -0
  132. tpu_inference/kernels/megablox/common.py +54 -0
  133. tpu_inference/kernels/megablox/gmm.py +646 -0
  134. tpu_inference/kernels/mla/__init__.py +13 -0
  135. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  136. tpu_inference/kernels/mla/v1/kernel.py +117 -145
  137. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  138. tpu_inference/kernels/quantized_matmul/kernel.py +69 -8
  139. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  140. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  141. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +2 -1
  142. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +2 -1
  143. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  144. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +194 -101
  145. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +95 -78
  146. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3817 -3504
  147. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +376 -195
  148. tpu_inference/kernels/ragged_paged_attention/v3/util.py +15 -1
  149. tpu_inference/layers/__init__.py +13 -0
  150. tpu_inference/layers/common/__init__.py +13 -0
  151. tpu_inference/layers/common/attention_interface.py +26 -19
  152. tpu_inference/layers/common/attention_metadata.py +14 -0
  153. tpu_inference/layers/common/quant_methods.py +15 -0
  154. tpu_inference/layers/common/quantization.py +270 -0
  155. tpu_inference/layers/common/sharding.py +28 -5
  156. tpu_inference/layers/jax/__init__.py +13 -0
  157. tpu_inference/layers/jax/attention/__init__.py +13 -0
  158. tpu_inference/layers/jax/attention/attention.py +19 -6
  159. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +270 -77
  160. tpu_inference/layers/jax/attention/gpt_oss_attention.py +24 -11
  161. tpu_inference/layers/jax/attention/llama4_attention.py +17 -4
  162. tpu_inference/layers/jax/base.py +14 -0
  163. tpu_inference/layers/jax/constants.py +13 -0
  164. tpu_inference/layers/jax/layers.py +14 -0
  165. tpu_inference/layers/jax/misc.py +14 -0
  166. tpu_inference/layers/jax/moe/__init__.py +13 -0
  167. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +20 -13
  168. tpu_inference/layers/jax/moe/gpt_oss_moe.py +14 -0
  169. tpu_inference/layers/jax/moe/moe.py +43 -3
  170. tpu_inference/layers/jax/pp_utils.py +53 -0
  171. tpu_inference/layers/jax/rope.py +14 -0
  172. tpu_inference/layers/jax/rope_interface.py +14 -0
  173. tpu_inference/layers/jax/sample/__init__.py +13 -0
  174. tpu_inference/layers/jax/sample/rejection_sampler.py +13 -0
  175. tpu_inference/layers/jax/sample/sampling.py +15 -1
  176. tpu_inference/layers/jax/sample/sampling_metadata.py +14 -0
  177. tpu_inference/layers/jax/transformer_block.py +14 -0
  178. tpu_inference/layers/vllm/__init__.py +13 -0
  179. tpu_inference/layers/vllm/attention.py +4 -4
  180. tpu_inference/layers/vllm/fused_moe.py +210 -260
  181. tpu_inference/layers/vllm/linear_common.py +57 -22
  182. tpu_inference/layers/vllm/quantization/__init__.py +16 -0
  183. tpu_inference/layers/vllm/quantization/awq.py +15 -1
  184. tpu_inference/layers/vllm/quantization/common.py +33 -18
  185. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  186. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +18 -3
  187. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +211 -148
  188. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  189. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +14 -0
  190. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +14 -0
  191. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  192. tpu_inference/layers/vllm/quantization/mxfp4.py +278 -209
  193. tpu_inference/layers/vllm/quantization/unquantized.py +134 -86
  194. tpu_inference/layers/vllm/sharding.py +21 -4
  195. tpu_inference/lora/__init__.py +13 -0
  196. tpu_inference/lora/torch_lora_ops.py +8 -13
  197. tpu_inference/models/__init__.py +13 -0
  198. tpu_inference/models/common/__init__.py +13 -0
  199. tpu_inference/models/common/model_loader.py +74 -35
  200. tpu_inference/models/jax/__init__.py +13 -0
  201. tpu_inference/models/jax/deepseek_v3.py +267 -157
  202. tpu_inference/models/jax/gpt_oss.py +26 -10
  203. tpu_inference/models/jax/jax_intermediate_tensor.py +14 -0
  204. tpu_inference/models/jax/llama3.py +99 -36
  205. tpu_inference/models/jax/llama4.py +14 -0
  206. tpu_inference/models/jax/llama_eagle3.py +14 -0
  207. tpu_inference/models/jax/llama_guard_4.py +15 -1
  208. tpu_inference/models/jax/qwen2.py +17 -2
  209. tpu_inference/models/jax/qwen2_5_vl.py +18 -4
  210. tpu_inference/models/jax/qwen3.py +17 -2
  211. tpu_inference/models/jax/utils/__init__.py +13 -0
  212. tpu_inference/models/jax/utils/file_utils.py +14 -0
  213. tpu_inference/models/jax/utils/multi_modal_utils.py +18 -4
  214. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  215. tpu_inference/models/jax/utils/{quantization/quantization_utils.py → qwix/qwix_utils.py} +88 -25
  216. tpu_inference/models/jax/utils/weight_utils.py +39 -2
  217. tpu_inference/models/vllm/__init__.py +13 -0
  218. tpu_inference/models/vllm/vllm_model_wrapper.py +20 -3
  219. tpu_inference/models/vllm/vllm_model_wrapper_context.py +14 -0
  220. tpu_inference/platforms/__init__.py +14 -0
  221. tpu_inference/platforms/tpu_platform.py +47 -64
  222. tpu_inference/runner/__init__.py +13 -0
  223. tpu_inference/runner/compilation_manager.py +72 -37
  224. tpu_inference/runner/kv_cache.py +54 -20
  225. tpu_inference/runner/kv_cache_manager.py +45 -15
  226. tpu_inference/runner/lora_utils.py +14 -0
  227. tpu_inference/runner/multimodal_manager.py +15 -1
  228. tpu_inference/runner/persistent_batch_manager.py +14 -0
  229. tpu_inference/runner/speculative_decoding_manager.py +14 -0
  230. tpu_inference/runner/structured_decoding_manager.py +14 -0
  231. tpu_inference/runner/tpu_runner.py +41 -16
  232. tpu_inference/spec_decode/__init__.py +13 -0
  233. tpu_inference/spec_decode/jax/__init__.py +13 -0
  234. tpu_inference/spec_decode/jax/eagle3.py +13 -0
  235. tpu_inference/tpu_info.py +14 -0
  236. tpu_inference/utils.py +42 -36
  237. tpu_inference/worker/__init__.py +13 -0
  238. tpu_inference/worker/tpu_worker.py +63 -50
  239. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/METADATA +11 -9
  240. tpu_inference-0.13.0rc2.post7.dist-info/RECORD +261 -0
  241. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  242. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +0 -5
  243. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +0 -6
  244. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +0 -5
  245. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +0 -6
  246. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +0 -105
  247. tpu_inference-0.11.1.dev202512030818.dist-info/RECORD +0 -174
  248. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/WHEEL +0 -0
  249. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/licenses/LICENSE +0 -0
  250. {tpu_inference-0.11.1.dev202512030818.dist-info → tpu_inference-0.13.0rc2.post7.dist-info}/top_level.txt +0 -0
tests/runner/test_tpu_runner.py
@@ -0,0 +1,261 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest.mock import MagicMock, patch
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
+                         SchedulerConfig, SpeculativeConfig, VllmConfig)
+
+from tpu_inference.runner.tpu_runner import TPUModelRunner
+
+
+class TestTPUJaxRunner:
+
+    def setup_method(self):
+        # Mock JAX dependencies
+        self.mock_devices = [MagicMock(coords=i) for i in range(1)]
+        self.mock_rng_key = MagicMock()
+        device_array = np.array(jax.devices()[:1]).reshape(1, 1, 1, -1)
+        self.mock_mesh = jax.make_mesh(device_array.shape,
+                                       ('data', 'attn_dp', 'expert', 'model'))
+        with patch('jax.devices', return_value=self.mock_devices), \
+             patch('jax.make_mesh', return_value=self.mock_mesh), \
+             patch('jax.random.key', return_value=self.mock_rng_key), \
+             patch('tpu_inference.runner.tpu_runner.get_model', return_value=MagicMock()), \
+             patch('tpu_inference.runner.tpu_runner.make_optimized_mesh', return_value=self.mock_mesh):
+
+            model_config = ModelConfig(tokenizer_mode="auto",
+                                       trust_remote_code=False,
+                                       seed=0,
+                                       dtype='bfloat16')
+            cache_config = CacheConfig(
+                block_size=16,
+                gpu_memory_utilization=0.9,
+                swap_space=4,
+                cache_dtype="auto",
+            )
+            scheduler_config = SchedulerConfig(max_num_seqs=16,
+                                               max_model_len=1024,
+                                               is_encoder_decoder=False)
+            parallel_config = ParallelConfig(
+                pipeline_parallel_size=1,
+                tensor_parallel_size=1,
+                worker_use_ray=False,
+            )
+            speculative_config = SpeculativeConfig(
+                model='ngram',
+                num_speculative_tokens=5,
+                prompt_lookup_max=4,
+            )
+            vllm_config = VllmConfig(
+                model_config=model_config,
+                cache_config=cache_config,
+                scheduler_config=scheduler_config,
+                parallel_config=parallel_config,
+                speculative_config=speculative_config,
+                observability_config={},
+                additional_config={},
+            )
+
+            self.runner = TPUModelRunner(vllm_config,
+                                         devices=self.mock_devices)
+
+    def test_get_supported_tasks_runner(self):
+        """Test get_supported_tasks for generate runner type."""
+        supported_tasks = self.runner.get_supported_tasks()
+        assert supported_tasks == ("generate", )
+
+    def test_get_input_ids_embeds(self):
+        """Tests _get_input_ids_embeds for both multimodal and text-only models."""
+        # 1. ===== Setup =====
+        dummy_input_ids = jnp.array([1, 2, 3])
+        dummy_mm_embeds = jnp.ones((10, 128))
+        dummy_final_embeds = jnp.ones((3, 128))
+
+        # Mock the embedding function
+        self.mock_get_input_embed_fn = MagicMock()
+        self.runner.embed_input_ids_fn = self.mock_get_input_embed_fn
+        self.mock_get_input_embed_fn.return_value = dummy_final_embeds
+        self.runner.state = MagicMock()
+
+        # 2. ===== Act & Assert (Multimodal) =====
+        self.runner.is_multimodal_model = True
+
+        input_ids_res, inputs_embeds_res = self.runner._get_input_ids_embeds(
+            dummy_input_ids, dummy_mm_embeds)
+
+        assert input_ids_res is None
+        np.testing.assert_array_equal(np.asarray(inputs_embeds_res),
+                                      np.asarray(dummy_final_embeds))
+        self.mock_get_input_embed_fn.assert_called_once_with(
+            self.runner.state, dummy_input_ids, dummy_mm_embeds)
+
+        # 3. ===== Act & Assert (Text-only) =====
+        self.mock_get_input_embed_fn.reset_mock()
+        self.runner.is_multimodal_model = False
+
+        input_ids_res, inputs_embeds_res = self.runner._get_input_ids_embeds(
+            dummy_input_ids, dummy_mm_embeds)
+
+        assert inputs_embeds_res is None
+        np.testing.assert_array_equal(np.asarray(input_ids_res),
+                                      np.asarray(dummy_input_ids))
+        self.mock_get_input_embed_fn.assert_not_called()
+
+    @patch('tpu_inference.runner.tpu_runner.TPUSupportedSamplingMetadata')
+    def test_prepare_inputs_hybrid_kvcache(self, mock_sampling_metadata):
+        # Create a hybrid KV cache config:
+        # 20 layers, 10 full attention + 10 sliding-window attention.
+        self._create_mock_hybrid_kv_cache_config()
+
+        # Mock scheduler output.
+        scheduler_output = MagicMock()
+        scheduler_output.total_num_scheduled_tokens = 10
+        scheduler_output.num_scheduled_tokens = {'req1': 10}
+        scheduler_output.scheduled_spec_decode_tokens = {}
+        scheduler_output.grammar_bitmask = None
+
+        # Mock input_batch
+        self.runner.input_batch = MagicMock()
+        self.runner.input_batch.num_reqs = 1
+        self.runner.input_batch.req_ids = ['req1']
+        self.runner.input_batch.req_id_to_index = {'req1': 0}
+        self.runner.input_batch.num_computed_tokens_cpu = np.array([10])
+        self.runner.input_batch.token_ids_cpu = np.random.randint(
+            0, 1000, (8, 64), dtype=np.int32)
+
+        # Mock block tables.
+        # There will be 2 block tables since there are 2 KV cache groups.
+        mock_block_table = MagicMock()
+        mock_block_table.get_cpu_tensor.return_value = np.zeros(
+            self.runner.block_tables_cpu[0].shape)
+        self.runner.input_batch.block_table = [
+            mock_block_table, mock_block_table
+        ]
+        self.runner.block_tables_cpu = [
+            np.zeros(self.runner.block_tables_cpu[0].shape, dtype=np.int32),
+            np.zeros(self.runner.block_tables_cpu[0].shape, dtype=np.int32)
+        ]
+
+        mock_sampling_instance = MagicMock()
+        mock_sampling_metadata.from_input_batch.return_value = mock_sampling_instance
+
+        output = self.runner._prepare_inputs_non_dp(scheduler_output)
+        assert len(output) == 8
+        (input_ids, positions, attention_metadata, sampling_metadata,
+         logits_indices, spec_decode_metadata, logits_indices_selector,
+         padded_num_reqs) = output
+        # Assert that attention metadata is created for each layer.
+        assert isinstance(attention_metadata, dict)
+        assert len(attention_metadata) == 20
+
+    def _create_mock_hybrid_kv_cache_config(self):
+        mock_kv_cache_config = MagicMock()
+        mock_kv_cache_group1 = MagicMock()
+        mock_kv_cache_group1.layer_names = [f'layer.{i}' for i in range(10)]
+        mock_kv_cache_group2 = MagicMock()
+        mock_kv_cache_group2.layer_names = [
+            f'layer.{i}' for i in range(10, 20)
+        ]
+        mock_kv_cache_config.kv_cache_groups = [
+            mock_kv_cache_group1, mock_kv_cache_group2
+        ]
+        self.runner.kv_cache_config = mock_kv_cache_config
+        self.runner.use_hybrid_kvcache = True
+
+
+class TestTPUJaxRunnerMultimodalModelLoadedForTextOnly:
+
+    def setup_method(self):
+        # Mock JAX dependencies
+        self.mock_devices = [MagicMock(coords=i) for i in range(4)]
+        self.mock_rng_key = MagicMock()
+        device_array = np.array(jax.devices()[:1]).reshape(1, 1, 1, -1)
+        self.mock_mesh = jax.make_mesh(device_array.shape,
+                                       ('data', 'attn_dp', 'expert', 'model'))
+        # Set up the runner with model_config.is_multimodal_model set to True,
+        # but with get_model returning None for embed_multimodal_fn and
+        # embed_input_ids_fn.
+        with patch('jax.devices', return_value=self.mock_devices), \
+             patch('jax.make_mesh', return_value=self.mock_mesh), \
+             patch('jax.random.key', return_value=self.mock_rng_key), \
+             patch('tpu_inference.runner.tpu_runner.nnx.Rngs', return_value=self.mock_rng_key), \
+             patch('tpu_inference.runner.tpu_runner.get_model', return_value=self._model_get_model()), \
+             patch('tpu_inference.runner.tpu_runner.make_optimized_mesh', return_value=self.mock_mesh):
+
+            model_config = ModelConfig(tokenizer_mode="auto",
+                                       trust_remote_code=False,
+                                       seed=0,
+                                       dtype='bfloat16')
+            # Set multimodal_config to not None, so that the
+            # is_multimodal_model property of model_config is True.
+            model_config.multimodal_config = MagicMock()
+
+            cache_config = CacheConfig(
+                block_size=16,
+                gpu_memory_utilization=0.9,
+                swap_space=4,
+                cache_dtype="auto",
+            )
+            scheduler_config = SchedulerConfig(max_num_seqs=16,
+                                               max_model_len=1024,
+                                               is_encoder_decoder=False)
+            parallel_config = ParallelConfig(
+                pipeline_parallel_size=1,
+                tensor_parallel_size=1,
+                worker_use_ray=False,
+            )
+            vllm_config = VllmConfig(
+                model_config=model_config,
+                cache_config=cache_config,
+                scheduler_config=scheduler_config,
+                parallel_config=parallel_config,
+                speculative_config=None,
+                observability_config={},
+                additional_config={},
+            )
+
+            self.runner = TPUModelRunner(vllm_config,
+                                         devices=self.mock_devices)
+            self.runner.load_model()
+
+    def _model_get_model(self):
+        mock_multimodal_fns = {
+            "precompile_vision_encoder_fn": None,
+            "embed_multimodal_fn": None,
+            "embed_input_ids_fn": None,
+            "get_mrope_input_positions_fn": None
+        }
+        return (
+            MagicMock(),  # TPUModelRunner.model_fn
+            MagicMock(),  # TPUModelRunner.compute_logits_fn
+            MagicMock(),  # TPUModelRunner.combine_hidden_states_fn
+            mock_multimodal_fns,  # TPUModelRunner.multimodal_fns
+            MagicMock(),  # TPUModelRunner.state (model params)
+            None,  # TPUModelRunner.lora_manager
+            None,  # TPUModelRunner.model
+        )
+
+    def test_is_multimodal_model(self):
+        # Precondition: the model_config claims the model supports multimodal.
+        assert self.runner.model_config.is_multimodal_model
+
+        # Precondition: the loaded model returned embed_multimodal_fn as None.
+        assert self.runner.embed_multimodal_fn is None
+
+        assert not self.runner.is_multimodal_model
+
+        self.runner.embed_input_ids_fn = MagicMock()
+        dummy_input_ids = jnp.array([1, 2, 3])
+        dummy_mm_embeds = jnp.ones((10, 128))
+        _ = self.runner._get_input_ids_embeds(dummy_input_ids, dummy_mm_embeds)
+        self.runner.embed_input_ids_fn.assert_not_called()