tpu-inference 0.12.0.dev20251222__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (260)
  1. tests/__init__.py +13 -0
  2. tests/core/__init__.py +13 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +67 -0
  6. tests/core/test_dp_scheduler.py +724 -0
  7. tests/core/test_init.py +63 -0
  8. tests/distributed/__init__.py +13 -0
  9. tests/distributed/test_distributed_utils.py +120 -0
  10. tests/distributed/test_tpu_connector.py +478 -0
  11. tests/e2e/__init__.py +13 -0
  12. tests/e2e/test_async_scheduler.py +211 -0
  13. tests/e2e/test_data_parallel.py +393 -0
  14. tests/e2e/test_local_disagg.py +257 -0
  15. tests/e2e/test_model_loader.py +268 -0
  16. tests/e2e/test_multi_modal_inference.py +111 -0
  17. tests/e2e/test_pipeline_parallel.py +265 -0
  18. tests/e2e/test_runai_model_streamer_loader.py +104 -0
  19. tests/e2e/test_sampling_params.py +269 -0
  20. tests/e2e/test_speculative_decoding.py +291 -0
  21. tests/e2e/test_structured_decoding.py +46 -0
  22. tests/executors/__init__.py +13 -0
  23. tests/executors/test_ray_distributed_executor.py +199 -0
  24. tests/experimental/__init__.py +13 -0
  25. tests/experimental/test_llama3_jax_stashed.py +208 -0
  26. tests/kernels/__init__.py +13 -0
  27. tests/kernels/collectives/__init__.py +13 -0
  28. tests/kernels/collectives/all_gather_matmul_kernel_test.py +69 -0
  29. tests/kernels/fused_moe_v1_test.py +388 -0
  30. tests/kernels/gmm_test.py +205 -0
  31. tests/kernels/mla_v1_test.py +498 -0
  32. tests/kernels/quantized_matmul_kernel_test.py +159 -0
  33. tests/kernels/ragged_kv_cache_update_v2_test.py +248 -0
  34. tests/kernels/ragged_paged_attention_kernel_v2_test.py +414 -0
  35. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +565 -0
  36. tests/kernels/ragged_paged_attention_kernel_v3_test.py +520 -0
  37. tests/layers/__init__.py +13 -0
  38. tests/layers/common/__init__.py +13 -0
  39. tests/layers/common/test_attention_interface.py +156 -0
  40. tests/layers/common/test_quantization.py +149 -0
  41. tests/layers/jax/__init__.py +13 -0
  42. tests/layers/jax/attention/__init__.py +13 -0
  43. tests/layers/jax/attention/test_common_attention.py +103 -0
  44. tests/layers/jax/attention/test_deepseek_v3_attention.py +233 -0
  45. tests/layers/jax/attention/test_llama4_attention.py +135 -0
  46. tests/layers/jax/moe/__init__.py +13 -0
  47. tests/layers/jax/moe/test_deepseek_moe.py +235 -0
  48. tests/layers/jax/sample/__init__.py +13 -0
  49. tests/layers/jax/sample/test_rejection_sampler.py +1624 -0
  50. tests/layers/jax/sample/test_sampling.py +115 -0
  51. tests/layers/jax/sample/test_sampling_metadata.py +254 -0
  52. tests/layers/jax/test_layers.py +155 -0
  53. tests/layers/jax/test_qwix.py +969 -0
  54. tests/layers/jax/test_rope.py +93 -0
  55. tests/layers/jax/test_sharding.py +159 -0
  56. tests/layers/jax/test_transformer_block.py +152 -0
  57. tests/layers/vllm/__init__.py +13 -0
  58. tests/layers/vllm/test_attention.py +363 -0
  59. tests/layers/vllm/test_awq.py +405 -0
  60. tests/layers/vllm/test_compressed_tensors_moe.py +202 -0
  61. tests/layers/vllm/test_compressed_tensors_w8a8_fp8.py +403 -0
  62. tests/layers/vllm/test_compressed_tensors_w8a8_int8.py +426 -0
  63. tests/layers/vllm/test_fp8.py +17 -0
  64. tests/layers/vllm/test_mxfp4.py +297 -0
  65. tests/layers/vllm/test_unquantized.py +621 -0
  66. tests/layers/vllm/utils.py +72 -0
  67. tests/lora/__init__.py +13 -0
  68. tests/lora/conftest.py +46 -0
  69. tests/lora/test_bgmv.py +57 -0
  70. tests/lora/test_layers.py +666 -0
  71. tests/lora/test_lora.py +147 -0
  72. tests/lora/test_lora_perf.py +67 -0
  73. tests/lora/utils.py +88 -0
  74. tests/models/__init__.py +13 -0
  75. tests/models/common/__init__.py +13 -0
  76. tests/models/common/test_model_loader.py +455 -0
  77. tests/models/jax/__init__.py +13 -0
  78. tests/models/jax/test_deepseek_v3.py +401 -0
  79. tests/models/jax/test_llama3.py +184 -0
  80. tests/models/jax/test_llama4.py +298 -0
  81. tests/models/jax/test_llama_eagle3.py +197 -0
  82. tests/models/jax/test_llama_guard_4.py +242 -0
  83. tests/models/jax/test_qwen2.py +172 -0
  84. tests/models/jax/test_qwen2_5_vl.py +606 -0
  85. tests/models/jax/test_qwen3.py +169 -0
  86. tests/models/jax/test_weight_loading.py +180 -0
  87. tests/models/jax/utils/__init__.py +13 -0
  88. tests/models/jax/utils/test_multi_modal_utils.py +212 -0
  89. tests/platforms/__init__.py +13 -0
  90. tests/platforms/test_tpu_platform.py +54 -0
  91. tests/runner/__init__.py +13 -0
  92. tests/runner/test_block_table.py +395 -0
  93. tests/runner/test_input_batch.py +226 -0
  94. tests/runner/test_kv_cache.py +220 -0
  95. tests/runner/test_kv_cache_manager.py +498 -0
  96. tests/runner/test_multimodal_manager.py +429 -0
  97. tests/runner/test_persistent_batch_manager.py +84 -0
  98. tests/runner/test_speculative_decoding_manager.py +368 -0
  99. tests/runner/test_structured_decoding_manager.py +220 -0
  100. tests/runner/test_tpu_runner.py +202 -0
  101. tests/runner/test_tpu_runner_dp.py +1033 -0
  102. tests/runner/test_tpu_runner_mesh.py +200 -0
  103. tests/runner/test_utils.py +411 -0
  104. tests/spec_decode/__init__.py +13 -0
  105. tests/spec_decode/test_eagle3.py +311 -0
  106. tests/test_base.py +215 -0
  107. tests/test_envs.py +280 -0
  108. tests/test_tpu_info.py +134 -0
  109. tests/test_utils.py +193 -0
  110. tests/worker/__init__.py +13 -0
  111. tests/worker/tpu_worker_test.py +414 -0
  112. tpu_inference/__init__.py +67 -0
  113. tpu_inference/core/__init__.py +13 -0
  114. tpu_inference/core/core_tpu.py +786 -0
  115. tpu_inference/core/disagg_executor.py +118 -0
  116. tpu_inference/core/disagg_utils.py +49 -0
  117. tpu_inference/core/sched/__init__.py +13 -0
  118. tpu_inference/core/sched/dp_scheduler.py +814 -0
  119. tpu_inference/distributed/__init__.py +13 -0
  120. tpu_inference/distributed/jax_parallel_state.py +81 -0
  121. tpu_inference/distributed/tpu_connector.py +732 -0
  122. tpu_inference/distributed/utils.py +112 -0
  123. tpu_inference/env_override.py +9 -0
  124. tpu_inference/envs.py +191 -0
  125. tpu_inference/executors/__init__.py +13 -0
  126. tpu_inference/executors/ray_distributed_executor.py +399 -0
  127. tpu_inference/experimental/__init__.py +13 -0
  128. tpu_inference/experimental/llama3_jax_stashed.py +272 -0
  129. tpu_inference/kernels/__init__.py +13 -0
  130. tpu_inference/kernels/collectives/__init__.py +13 -0
  131. tpu_inference/kernels/collectives/all_gather_matmul.py +741 -0
  132. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +65 -0
  133. tpu_inference/kernels/collectives/util.py +47 -0
  134. tpu_inference/kernels/flash_attention/__init__.py +13 -0
  135. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  136. tpu_inference/kernels/fused_moe/__init__.py +13 -0
  137. tpu_inference/kernels/fused_moe/v1/__init__.py +13 -0
  138. tpu_inference/kernels/fused_moe/v1/kernel.py +1612 -0
  139. tpu_inference/kernels/megablox/__init__.py +13 -0
  140. tpu_inference/kernels/megablox/common.py +54 -0
  141. tpu_inference/kernels/megablox/gmm.py +646 -0
  142. tpu_inference/kernels/mla/__init__.py +13 -0
  143. tpu_inference/kernels/mla/v1/__init__.py +13 -0
  144. tpu_inference/kernels/mla/v1/kernel.py +1340 -0
  145. tpu_inference/kernels/quantized_matmul/__init__.py +13 -0
  146. tpu_inference/kernels/quantized_matmul/kernel.py +456 -0
  147. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  148. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  149. tpu_inference/kernels/ragged_paged_attention/__init__.py +13 -0
  150. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +13 -0
  151. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +876 -0
  152. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +288 -0
  153. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  154. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +13 -0
  155. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1594 -0
  156. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1586 -0
  157. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4460 -0
  158. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +548 -0
  159. tpu_inference/kernels/ragged_paged_attention/v3/util.py +65 -0
  160. tpu_inference/layers/__init__.py +13 -0
  161. tpu_inference/layers/common/__init__.py +13 -0
  162. tpu_inference/layers/common/attention_interface.py +403 -0
  163. tpu_inference/layers/common/attention_metadata.py +48 -0
  164. tpu_inference/layers/common/binary_search.py +295 -0
  165. tpu_inference/layers/common/quant_methods.py +23 -0
  166. tpu_inference/layers/common/quantization.py +270 -0
  167. tpu_inference/layers/common/sharding.py +600 -0
  168. tpu_inference/layers/jax/__init__.py +13 -0
  169. tpu_inference/layers/jax/attention/__init__.py +13 -0
  170. tpu_inference/layers/jax/attention/attention.py +268 -0
  171. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +547 -0
  172. tpu_inference/layers/jax/attention/gpt_oss_attention.py +275 -0
  173. tpu_inference/layers/jax/attention/llama4_attention.py +167 -0
  174. tpu_inference/layers/jax/base.py +165 -0
  175. tpu_inference/layers/jax/constants.py +101 -0
  176. tpu_inference/layers/jax/layers.py +315 -0
  177. tpu_inference/layers/jax/misc.py +30 -0
  178. tpu_inference/layers/jax/moe/__init__.py +13 -0
  179. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +615 -0
  180. tpu_inference/layers/jax/moe/gpt_oss_moe.py +199 -0
  181. tpu_inference/layers/jax/moe/moe.py +249 -0
  182. tpu_inference/layers/jax/pp_utils.py +53 -0
  183. tpu_inference/layers/jax/rope.py +294 -0
  184. tpu_inference/layers/jax/rope_interface.py +228 -0
  185. tpu_inference/layers/jax/sample/__init__.py +13 -0
  186. tpu_inference/layers/jax/sample/rejection_sampler.py +528 -0
  187. tpu_inference/layers/jax/sample/sampling.py +110 -0
  188. tpu_inference/layers/jax/sample/sampling_metadata.py +90 -0
  189. tpu_inference/layers/jax/transformer_block.py +121 -0
  190. tpu_inference/layers/vllm/__init__.py +13 -0
  191. tpu_inference/layers/vllm/attention.py +221 -0
  192. tpu_inference/layers/vllm/fused_moe.py +502 -0
  193. tpu_inference/layers/vllm/linear_common.py +221 -0
  194. tpu_inference/layers/vllm/quantization/__init__.py +55 -0
  195. tpu_inference/layers/vllm/quantization/awq.py +221 -0
  196. tpu_inference/layers/vllm/quantization/common.py +124 -0
  197. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +13 -0
  198. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +135 -0
  199. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +266 -0
  200. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +13 -0
  201. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +222 -0
  202. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +150 -0
  203. tpu_inference/layers/vllm/quantization/fp8.py +118 -0
  204. tpu_inference/layers/vllm/quantization/mxfp4.py +396 -0
  205. tpu_inference/layers/vllm/quantization/unquantized.py +416 -0
  206. tpu_inference/layers/vllm/sharding.py +244 -0
  207. tpu_inference/logger.py +10 -0
  208. tpu_inference/lora/__init__.py +13 -0
  209. tpu_inference/lora/torch_lora_ops.py +98 -0
  210. tpu_inference/lora/torch_punica_tpu.py +310 -0
  211. tpu_inference/models/__init__.py +13 -0
  212. tpu_inference/models/common/__init__.py +13 -0
  213. tpu_inference/models/common/model_loader.py +520 -0
  214. tpu_inference/models/jax/__init__.py +13 -0
  215. tpu_inference/models/jax/deepseek_v3.py +978 -0
  216. tpu_inference/models/jax/gpt_oss.py +508 -0
  217. tpu_inference/models/jax/jax_intermediate_tensor.py +93 -0
  218. tpu_inference/models/jax/llama3.py +436 -0
  219. tpu_inference/models/jax/llama4.py +643 -0
  220. tpu_inference/models/jax/llama_eagle3.py +350 -0
  221. tpu_inference/models/jax/llama_guard_4.py +375 -0
  222. tpu_inference/models/jax/qwen2.py +390 -0
  223. tpu_inference/models/jax/qwen2_5_vl.py +1232 -0
  224. tpu_inference/models/jax/qwen3.py +318 -0
  225. tpu_inference/models/jax/utils/__init__.py +13 -0
  226. tpu_inference/models/jax/utils/file_utils.py +110 -0
  227. tpu_inference/models/jax/utils/multi_modal_utils.py +177 -0
  228. tpu_inference/models/jax/utils/qwix/__init__.py +13 -0
  229. tpu_inference/models/jax/utils/qwix/qwix_utils.py +713 -0
  230. tpu_inference/models/jax/utils/weight_utils.py +621 -0
  231. tpu_inference/models/vllm/__init__.py +13 -0
  232. tpu_inference/models/vllm/vllm_model_wrapper.py +307 -0
  233. tpu_inference/models/vllm/vllm_model_wrapper_context.py +59 -0
  234. tpu_inference/platforms/__init__.py +16 -0
  235. tpu_inference/platforms/tpu_platform.py +258 -0
  236. tpu_inference/runner/__init__.py +13 -0
  237. tpu_inference/runner/block_table.py +122 -0
  238. tpu_inference/runner/compilation_manager.py +890 -0
  239. tpu_inference/runner/input_batch.py +435 -0
  240. tpu_inference/runner/kv_cache.py +166 -0
  241. tpu_inference/runner/kv_cache_manager.py +508 -0
  242. tpu_inference/runner/lora_utils.py +106 -0
  243. tpu_inference/runner/multimodal_manager.py +231 -0
  244. tpu_inference/runner/persistent_batch_manager.py +296 -0
  245. tpu_inference/runner/speculative_decoding_manager.py +262 -0
  246. tpu_inference/runner/structured_decoding_manager.py +101 -0
  247. tpu_inference/runner/tpu_runner.py +1768 -0
  248. tpu_inference/runner/utils.py +426 -0
  249. tpu_inference/spec_decode/__init__.py +13 -0
  250. tpu_inference/spec_decode/jax/__init__.py +13 -0
  251. tpu_inference/spec_decode/jax/eagle3.py +430 -0
  252. tpu_inference/tpu_info.py +92 -0
  253. tpu_inference/utils.py +345 -0
  254. tpu_inference/worker/__init__.py +13 -0
  255. tpu_inference/worker/tpu_worker.py +468 -0
  256. tpu_inference-0.12.0.dev20251222.dist-info/METADATA +106 -0
  257. tpu_inference-0.12.0.dev20251222.dist-info/RECORD +260 -0
  258. tpu_inference-0.12.0.dev20251222.dist-info/WHEEL +5 -0
  259. tpu_inference-0.12.0.dev20251222.dist-info/licenses/LICENSE +201 -0
  260. tpu_inference-0.12.0.dev20251222.dist-info/top_level.txt +2 -0
tpu_inference/runner/multimodal_manager.py
@@ -0,0 +1,231 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING
+
+ import jax
+ import jax.numpy as jnp
+ from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
+ from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
+ from vllm.multimodal.utils import group_mm_kwargs_by_modality
+ from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
+ from vllm.v1.worker.utils import (gather_mm_placeholders,
+                                   scatter_mm_placeholders)
+
+ from tpu_inference.models.jax.utils.multi_modal_utils import (
+     flatten_embeddings, sanity_check_mm_encoder_outputs)
+
+ if TYPE_CHECKING:
+     from tpu_inference.runner.tpu_runner import TPUModelRunner
+
+
+ class MultiModalManager:
+
+     def __init__(self, runner: "TPUModelRunner"):
+         self.runner = runner
+
+     def calc_mrope_positions(self, scheduler_output: "VllmSchedulerOutput"):
+         mrope_pos_ptr = 0
+         for index, req_id in enumerate(self.runner.input_batch.req_ids):
+             req = self.runner.requests[req_id]
+             assert req.mrope_positions is not None
+
+             num_computed_tokens = \
+                 self.runner.input_batch.num_computed_tokens_cpu[index]
+             num_scheduled_tokens = \
+                 scheduler_output.num_scheduled_tokens[req_id]
+             num_prompt_tokens = len(req.prompt_token_ids)
+
+             if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
+                 prompt_part_len = max(0,
+                                       num_prompt_tokens - num_computed_tokens)
+                 completion_part_len = max(
+                     0, num_scheduled_tokens - prompt_part_len)
+             else:
+                 prompt_part_len = num_scheduled_tokens
+                 completion_part_len = 0
+
+             assert num_scheduled_tokens == prompt_part_len + completion_part_len
+
+             if prompt_part_len > 0:
+                 # prompt's mrope_positions are pre-computed
+                 dst_start = mrope_pos_ptr
+                 dst_end = mrope_pos_ptr + prompt_part_len
+                 src_start = num_computed_tokens
+                 src_end = num_computed_tokens + prompt_part_len
+
+                 self.runner.mrope_positions_cpu[:, dst_start:dst_end] = \
+                     req.mrope_positions[:, src_start:src_end]
+
+                 mrope_pos_ptr += prompt_part_len
+
+             if completion_part_len > 0:
+                 # compute completion's mrope_positions on-the-fly
+                 dst_start = mrope_pos_ptr
+                 dst_end = mrope_pos_ptr + completion_part_len
+
+                 MRotaryEmbedding.get_next_input_positions_tensor(
+                     out=self.runner.mrope_positions_cpu,
+                     out_offset=dst_start,
+                     mrope_position_delta=req.mrope_position_delta,
+                     context_len=num_computed_tokens + prompt_part_len,
+                     num_new_tokens=completion_part_len,
+                 )
+
+                 mrope_pos_ptr += completion_part_len
+
+     def execute_mm_encoder(self, scheduler_output: "VllmSchedulerOutput"):
+         import torch
+         scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
+         if not scheduled_encoder_inputs:
+             return
+
+         # Batch the multi-modal inputs.
+         mm_kwargs = list[MultiModalKwargsItem]()
+         # List of tuple (mm_hash, pos_info)
+         mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
+         for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
+             req_state = self.runner.requests[req_id]
+             for mm_input_id in encoder_input_ids:
+                 mm_feature = req_state.mm_features[mm_input_id]
+                 mm_hash = mm_feature.identifier
+                 mm_kwargs.append(mm_feature.data)
+                 mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
+
+         # Batch mm inputs as much as we can: if a request in the batch has
+         # multiple modalities or a different modality than the previous one,
+         # we process it separately to preserve item order.
+         # FIXME(ywang96): This is a hacky way to deal with multiple modalities
+         # in the same batch while still being able to benefit from batching
+         # multimodal inputs. The proper solution should be reordering the
+         # encoder outputs.
+         encoder_outputs = []
+         for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
+                 mm_kwargs, merge_by_field_config=False):
+             batched_mm_inputs = mm_kwargs_group
+             # Convert torch tensors to numpy arrays that JAX can handle.
+             if "pixel_values" in batched_mm_inputs and isinstance(
+                     batched_mm_inputs["pixel_values"], list):
+                 batched_mm_inputs["pixel_values"] = torch.cat(
+                     batched_mm_inputs["pixel_values"], dim=0)
+
+             image_grid_thw = ()
+             for key, value in batched_mm_inputs.items():
+                 if isinstance(value, torch.Tensor):
+                     if key == 'image_grid_thw':
+                         # change it to tuple of tuples to make it hashable for JIT
+
+                         # Shape: (B, N, 3) -> (B*N, 3) -> tuple of tuples
+                         grid_thw_tensor = batched_mm_inputs[key]
+                         grid_thw_reshaped = grid_thw_tensor.reshape(-1, 3)
+                         image_grid_thw = tuple(
+                             tuple(row) for row in grid_thw_reshaped.tolist())
+
+                         continue
+
+                     if value.dtype == torch.bfloat16:
+                         batched_mm_inputs[key] = value.to(
+                             torch.float32).numpy().astype(jnp.bfloat16)
+                     else:
+                         batched_mm_inputs[key] = value.numpy()
+             batched_mm_inputs.pop('image_grid_thw')
+
+             # Run the encoder.
+             # `curr_group_outputs` is either of the following:
+             # 1. A tensor of shape (num_items, feature_size, hidden_size)
+             #    in case feature_size is fixed across all multimodal items.
+             # 2. A list or tuple (length: num_items) of tensors, each of shape
+             #    (feature_size, hidden_size) in case the feature size is dynamic
+             #    depending on the input multimodal items.
+             curr_group_outputs = self.runner.get_multimodal_embeddings_fn(
+                 self.runner.state, image_grid_thw, **batched_mm_inputs)
+
+             sanity_check_mm_encoder_outputs(
+                 curr_group_outputs,
+                 expected_num_items=num_items,
+             )
+
+             for output in curr_group_outputs:
+                 encoder_outputs.append(output)
+
+         # Cache the encoder outputs.
+         for (mm_hash, pos_info), output in zip(
+                 mm_hashes_pos,
+                 encoder_outputs,
+         ):
+             if req_id not in self.runner.encoder_cache:
+                 self.runner.encoder_cache[req_id] = {}
+
+             self.runner.encoder_cache[mm_hash] = scatter_mm_placeholders(
+                 output,
+                 is_embed=pos_info.is_embed,
+             )
+
+     def gather_mm_embeddings(self, scheduler_output: "VllmSchedulerOutput",
+                              target_pad_len: int) -> list[jax.Array]:
+         mm_embeds: list[jax.Array] = []
+         for req_id in self.runner.input_batch.req_ids:
+             num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
+                 req_id]
+             req_state = self.runner.requests[req_id]
+             num_computed_tokens = req_state.num_computed_tokens
+             mm_features = req_state.mm_features
+             for _, mm_feature in enumerate(mm_features):
+                 pos_info = mm_feature.mm_position
+                 start_pos = pos_info.offset
+                 num_encoder_tokens = pos_info.length
+
+                 # The encoder output is needed if the two ranges overlap:
+                 # [num_computed_tokens,
+                 #  num_computed_tokens + num_scheduled_tokens) and
+                 # [start_pos, start_pos + num_encoder_tokens)
+                 if start_pos >= num_computed_tokens + num_scheduled_tokens:
+                     # The encoder output is not needed in this step.
+                     break
+                 if start_pos + num_encoder_tokens <= num_computed_tokens:
+                     # The encoder output is already processed and stored
+                     # in the decoder's KV cache.
+                     continue
+
+                 start_idx = max(num_computed_tokens - start_pos, 0)
+                 end_idx = min(
+                     num_computed_tokens - start_pos + num_scheduled_tokens,
+                     num_encoder_tokens)
+                 assert start_idx < end_idx
+                 mm_hash = mm_feature.identifier
+                 encoder_output = self.runner.encoder_cache.get(mm_hash, None)
+                 assert encoder_output is not None, \
+                     f"Encoder cache miss for {mm_hash}."
+                 encoder_output = self.runner.encoder_cache[mm_hash]
+
+                 if (is_embed := pos_info.is_embed) is not None:
+                     is_embed = is_embed[start_idx:end_idx]
+
+                 mm_embeds_item = gather_mm_placeholders(
+                     encoder_output[start_idx:end_idx],
+                     is_embed=is_embed,
+                 )
+                 mm_embeds.append(mm_embeds_item)
+         if not mm_embeds:
+             return None
+         flattened_embeds = flatten_embeddings(mm_embeds)
+         if flattened_embeds.shape[0] == 0:
+             return None
+
+         padding = jnp.zeros((target_pad_len - flattened_embeds.shape[0],
+                              flattened_embeds.shape[1]),
+                             dtype=flattened_embeds.dtype)
+         flattened_embeds = jnp.concatenate([flattened_embeds, padding], axis=0)
+
+         return flattened_embeds
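The window arithmetic in gather_mm_embeddings above is the core of the multimodal path: a cached encoder output is consumed only where its placeholder range overlaps the tokens scheduled in the current step. The standalone Python sketch below is not part of the wheel (the function name encoder_slice_for_step is illustrative); it reproduces that arithmetic so the overlap logic can be checked in isolation.

# Minimal, self-contained sketch of the overlap computation used in
# gather_mm_embeddings. Illustrative only; not package code.
from typing import Optional, Tuple


def encoder_slice_for_step(num_computed_tokens: int,
                           num_scheduled_tokens: int,
                           start_pos: int,
                           num_encoder_tokens: int) -> Optional[Tuple[int, int]]:
    """Return the [start_idx, end_idx) slice of the encoder output needed now.

    The slice is the overlap of the step window
    [num_computed_tokens, num_computed_tokens + num_scheduled_tokens)
    with the placeholder range [start_pos, start_pos + num_encoder_tokens),
    expressed relative to the encoder output. Returns None when the ranges
    do not overlap (placeholder not reached yet, or already in the KV cache).
    """
    step_end = num_computed_tokens + num_scheduled_tokens
    if start_pos >= step_end:
        return None  # placeholder not reached in this step
    if start_pos + num_encoder_tokens <= num_computed_tokens:
        return None  # already processed and stored in the KV cache
    start_idx = max(num_computed_tokens - start_pos, 0)
    end_idx = min(step_end - start_pos, num_encoder_tokens)
    return start_idx, end_idx


# Example: 100 prompt tokens already computed, 64 scheduled this step, and an
# image placeholder spanning tokens [90, 154) -> slice [10, 64) of the
# 64-token encoder output is gathered.
assert encoder_slice_for_step(100, 64, 90, 64) == (10, 64)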
tpu_inference/runner/persistent_batch_manager.py
@@ -0,0 +1,296 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Dict
+
+ import jax
+ from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
+
+ from tpu_inference.logger import init_logger
+ from tpu_inference.runner.input_batch import CachedRequestState, InputBatch
+
+ logger = init_logger(__name__)
+
+
+ class PersistentBatchManager:
+
+     def __init__(self, requests: Dict[str, CachedRequestState],
+                  input_batch: InputBatch,
+                  encoder_cache: Dict[str, 'jax.Array'],
+                  uses_mrope: bool, model_config, is_last_rank: bool):
+         self.requests = requests
+         self.input_batch = input_batch
+         self.encoder_cache = encoder_cache
+         self.uses_mrope = uses_mrope
+         self.model_config = model_config
+         self.is_last_rank = is_last_rank
+
+     def _reorder_batch(self, scheduler_output: "VllmSchedulerOutput") -> int:
+         """Reorder the scheduled requests to an RPA-kernel-friendly distribution
+         (decode_only, fixed_chunked_prefill_only, mixed) and set the request
+         distribution accordingly.
+
+         Returns:
+             The number of swaps in requests.
+         """
+         # Note(jevinjiang): currently we only consider decode_only.
+         num_reqs = self.input_batch.num_reqs
+         swap_cnt = 0
+         if num_reqs <= 0:
+             return swap_cnt
+         # Use two-pointer approach to reorder the decode requests to front.
+         i, j = 0, num_reqs - 1
+         while i < j:
+             i_req_id = self.input_batch.req_ids[i]
+             j_req_id = self.input_batch.req_ids[j]
+
+             if scheduler_output.num_scheduled_tokens[i_req_id] == 1:
+                 # i is a decode request, move to the next one.
+                 i += 1
+             elif scheduler_output.num_scheduled_tokens[j_req_id] > 1:
+                 # j is a prefill request, move to the previous one.
+                 j -= 1
+             else:
+                 # Swap i and j.
+                 self.input_batch.swap_states(i, j)
+                 i += 1
+                 j -= 1
+                 swap_cnt += 1
+
+         num_decode = i + int(scheduler_output.num_scheduled_tokens[
+             self.input_batch.req_ids[i]] == 1)
+
+         self.input_batch.request_distribution = [
+             num_decode, num_decode, num_reqs
+         ]
+
+         return swap_cnt
+
+     def update_states(self, scheduler_output: "VllmSchedulerOutput",
+                       get_mrope_input_positions_fn) -> bool:
+         """Update the cached states and the persistent batch with the scheduler
+         output.
+
+         The updated states are used by the `_prepare_inputs` function to create
+         the input TPU tensors for the model.
+
+         Returns:
+             True if there is a new/resumed/paused/finished request.
+             If False, we can skip copying SamplingMetadata to the TPU.
+         """
+         # Remove finished requests from the cached states.
+         for req_id in scheduler_output.finished_req_ids:
+             self.requests.pop(req_id, None)
+
+         # Remove the finished requests from the persistent batch.
+         # NOTE(woosuk): There could be an edge case where finished_req_ids and
+         # scheduled_req_ids overlap. This happens when a request is aborted and
+         # then resubmitted with the same ID. In this case, we treat them as two
+         # distinct requests - clearing the cached states for the first request
+         # and handling the second as a new request.
+         removed_req_indices: list[int] = []
+         for req_id in scheduler_output.finished_req_ids:
+             req_index = self.input_batch.remove_request(req_id)
+             if req_index is not None:
+                 removed_req_indices.append(req_index)
+
+         # Free the cached encoder outputs.
+         for mm_hash in scheduler_output.free_encoder_mm_hashes:
+             self.encoder_cache.pop(mm_hash, None)
+
+         # Remove the unscheduled requests from the persistent batch.
+         # NOTE(woosuk): The unscheduled requests are either preempted requests
+         # or running requests that are not scheduled in this step. We remove
+         # them from the persistent batch but keep their cached states since
+         # they will be scheduled again sometime in the future.
+         scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
+         cached_req_ids = self.input_batch.req_id_to_index.keys()
+         unscheduled_req_ids = cached_req_ids - scheduled_req_ids
+         # NOTE(woosuk): The persistent batch optimization assumes that
+         # consecutive batches contain mostly the same requests. If batches
+         # have low request overlap (e.g., alternating between two distinct
+         # sets of requests), this optimization becomes very inefficient.
+         for req_id in unscheduled_req_ids:
+             req_index = self.input_batch.remove_request(req_id)
+             assert req_index is not None
+             removed_req_indices.append(req_index)
+
+         req_ids_to_add: list[str] = []
+         # Add new requests to the cached states.
+         for new_req_data in scheduler_output.scheduled_new_reqs:
+             req_id = new_req_data.req_id
+             sampling_params = new_req_data.sampling_params
+
+             self.requests[req_id] = CachedRequestState(
+                 req_id=req_id,
+                 prompt_token_ids=new_req_data.prompt_token_ids,
+                 mm_features=new_req_data.mm_features,
+                 sampling_params=sampling_params,
+                 pooling_params=None,
+                 generator=None,
+                 block_ids=new_req_data.block_ids,
+                 num_computed_tokens=new_req_data.num_computed_tokens,
+                 output_token_ids=[],
+                 lora_request=new_req_data.lora_request,
+             )
+
+             req_ids_to_add.append(req_id)
+
+             # Only relevant for models using M-RoPE (e.g., Qwen2-VL)
+             if self.uses_mrope:
+                 image_grid_thw = []
+                 video_grid_thw = []
+                 second_per_grid_ts = []
+                 audio_feature_lengths = []
+                 use_audio_in_video = False
+                 for mm_feature in self.requests[req_id].mm_features:
+                     item = mm_feature.data
+                     if item is None:
+                         continue
+                     mm_input = item.get_data()
+                     if mm_input.get("image_grid_thw") is not None:
+                         image_grid_thw.append(
+                             mm_input["image_grid_thw"].tolist())
+                     if mm_input.get("video_grid_thw") is not None:
+                         video_grid_thw.append(
+                             mm_input["video_grid_thw"].tolist())
+                     if mm_input.get("second_per_grid_ts") is not None:
+                         second_per_grid_ts.append(
+                             mm_input["second_per_grid_ts"])
+                     if mm_input.get("audio_feature_lengths") is not None:
+                         audio_feature_lengths.append(
+                             mm_input["audio_feature_lengths"])
+                     if mm_input.get("use_audio_in_video") is True:
+                         use_audio_in_video = True
+
+                 hf_config = self.model_config.hf_config
+
+                 self.requests[req_id].mrope_positions, self.requests[
+                     req_id].mrope_position_delta = get_mrope_input_positions_fn(
+                         self.requests[req_id].prompt_token_ids,
+                         hf_config=hf_config,
+                         image_grid_thw=image_grid_thw,
+                         video_grid_thw=video_grid_thw,
+                         second_per_grid_ts=second_per_grid_ts,
+                         audio_feature_lengths=audio_feature_lengths,
+                         use_audio_in_video=use_audio_in_video,
+                     )
+
+         # Update the states of the running/resumed requests.
+         req_data = scheduler_output.scheduled_cached_reqs
+         for i, req_id in enumerate(req_data.req_ids):
+             req_state = self.requests[req_id]
+             num_computed_tokens = req_data.num_computed_tokens[i]
+             new_block_ids = req_data.new_block_ids[i]
+             resumed_from_preemption = req_data.resumed_from_preemption[i]
+             num_output_tokens = req_data.num_output_tokens[i]
+
+             # Update the cached states.
+             req_state.num_computed_tokens = num_computed_tokens
+             req_index = self.input_batch.req_id_to_index.get(req_id)
+
+             if not self.is_last_rank:
+                 # When using PP, the scheduler sends the sampled tokens back,
+                 # because there's no direct communication between the first-
+                 # stage worker and the last-stage worker.
+                 new_token_ids = req_data.new_token_ids[i]
+                 # Add the sampled token(s) from the previous step (if any).
+                 # This doesn't include "unverified" tokens like spec tokens.
+                 num_new_tokens = (num_computed_tokens + len(new_token_ids) -
+                                   req_state.num_tokens)
+                 if num_new_tokens == 1:
+                     req_state.output_token_ids.append(new_token_ids[-1])
+                 elif num_new_tokens > 0:
+                     req_state.output_token_ids.extend(
+                         new_token_ids[-num_new_tokens:])
+                 elif num_output_tokens < len(req_state.output_token_ids):
+                     del req_state.output_token_ids[num_output_tokens:]
+                     if req_index is not None:
+                         end_idx = (self.input_batch.num_prompt_tokens[req_index]
+                                    + num_output_tokens)
+                         self.input_batch.num_tokens[req_index] = end_idx
+                         self.input_batch.num_tokens_no_spec[req_index] = end_idx
+
+             # Update the block IDs.
+             if not resumed_from_preemption:
+                 if new_block_ids is not None:
+                     # Append the new blocks to the existing block IDs.
+                     for block_ids, new_ids in zip(req_state.block_ids,
+                                                   new_block_ids):
+                         block_ids.extend(new_ids)
+             else:
+                 assert new_block_ids is not None
+                 # The request is resumed from preemption.
+                 # Replace the existing block IDs with the new ones.
+                 req_state.block_ids = new_block_ids
+
+             if req_index is None:
+                 # The request is not in the persistent batch.
+                 # The request was either preempted and resumed later, or was not
+                 # scheduled in the previous step and needs to be added again.
+                 req_ids_to_add.append(req_id)
+                 continue
+
+             # Update the persistent batch.
+             self.input_batch.num_computed_tokens_cpu[
+                 req_index] = num_computed_tokens
+             if new_block_ids is not None:
+                 self.input_batch.block_table.append_row(
+                     new_block_ids, req_index)
+
+             # For the last rank, we don't need to update the token_ids_cpu
+             # because the sampled tokens are already cached.
+             if not self.is_last_rank:
+                 start_token_index = num_computed_tokens
+                 end_token_index = num_computed_tokens + len(new_token_ids)
+                 self.input_batch.token_ids_cpu[
+                     req_index,
+                     start_token_index:end_token_index] = new_token_ids
+                 self.input_batch.num_tokens_no_spec[
+                     req_index] = end_token_index
+                 self.input_batch.num_tokens[req_index] = end_token_index
+
+             # Add spec_token_ids to token_ids_cpu.
+             spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
+                 req_id, ())
+             if spec_token_ids:
+                 num_spec_tokens = len(spec_token_ids)
+                 start_index = self.input_batch.num_tokens_no_spec[req_index]
+                 end_token_index = start_index + num_spec_tokens
+                 self.input_batch.token_ids_cpu[
+                     req_index, start_index:end_token_index] = spec_token_ids
+                 # NOTE(woosuk): `num_tokens` here may include spec tokens.
+                 self.input_batch.num_tokens[req_index] += num_spec_tokens
+
+         # Add the new or resumed requests to the persistent batch.
+         # The smaller empty indices are filled first.
+         removed_req_indices = sorted(removed_req_indices, reverse=True)
+         for req_id in req_ids_to_add:
+             req_state = self.requests[req_id]
+             if removed_req_indices:
+                 # Fill the empty index.
+                 req_index = removed_req_indices.pop()
+             else:
+                 # Append to the end.
+                 req_index = None
+             self.input_batch.add_request(req_state, req_index)
+
+         # Condense the batched states if there are empty indices.
+         if removed_req_indices:
+             self.input_batch.condense(removed_req_indices)
+
+         batch_changed = len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
+         # TODO(jevinjiang): I assume we do not need to set batch_changed to
+         # true if just swapping requests.
+         self._reorder_batch(scheduler_output)
+         return batch_changed
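The _reorder_batch method above partitions the scheduled requests so that decode requests (exactly one scheduled token) come first, which is the layout the ragged-paged-attention kernel expects. The standalone Python sketch below is not part of the wheel (reorder_decodes_first is an illustrative name, and it shuffles a plain list rather than InputBatch rows); it shows the same two-pointer pass on toy data.

# Minimal, self-contained sketch of the two-pointer partition used by
# _reorder_batch. Illustrative only; not package code.
from typing import Dict, List, Tuple


def reorder_decodes_first(req_ids: List[str],
                          num_scheduled_tokens: Dict[str, int]) -> Tuple[int, int]:
    """Partition req_ids in place so decodes come first.

    Returns (num_decode, swap_cnt), mirroring the counters the real method
    derives for the request distribution and its return value.
    """
    if not req_ids:
        return 0, 0
    i, j = 0, len(req_ids) - 1
    swap_cnt = 0
    while i < j:
        if num_scheduled_tokens[req_ids[i]] == 1:
            i += 1      # already a decode, keep it in front
        elif num_scheduled_tokens[req_ids[j]] > 1:
            j -= 1      # already a prefill, keep it at the back
        else:
            # Front element is a prefill and back element is a decode: swap.
            req_ids[i], req_ids[j] = req_ids[j], req_ids[i]
            i += 1
            j -= 1
            swap_cnt += 1
    num_decode = i + int(num_scheduled_tokens[req_ids[i]] == 1)
    return num_decode, swap_cnt


reqs = ["p0", "d0", "p1", "d1"]
tokens = {"p0": 17, "d0": 1, "p1": 33, "d1": 1}
print(reorder_decodes_first(reqs, tokens))  # (2, 1); reqs is now ["d1", "d0", "p1", "p0"]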