tpu_inference-0.0.1rc1-py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the differences between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (174)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_core_tpu.py +513 -0
  4. tests/core/test_disagg_executor.py +60 -0
  5. tests/core/test_disagg_utils.py +53 -0
  6. tests/core/test_dp_scheduler.py +899 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/fused_moe_v1_test.py +374 -0
  10. tests/kernels/mla_v1_test.py +396 -0
  11. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  12. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  13. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  14. tests/kernels/ragged_paged_attention_kernel_v3_hd64_test.py +549 -0
  15. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  16. tests/lora/__init__.py +0 -0
  17. tests/lora/conftest.py +32 -0
  18. tests/lora/test_bgmv.py +43 -0
  19. tests/lora/test_layers.py +648 -0
  20. tests/lora/test_lora.py +133 -0
  21. tests/lora/utils.py +88 -0
  22. tests/test_base.py +201 -0
  23. tests/test_envs.py +203 -0
  24. tests/test_quantization.py +836 -0
  25. tests/test_tpu_info.py +120 -0
  26. tests/test_utils.py +235 -0
  27. tpu_inference/__init__.py +53 -0
  28. tpu_inference/core/__init__.py +0 -0
  29. tpu_inference/core/core_tpu.py +786 -0
  30. tpu_inference/core/disagg_executor.py +118 -0
  31. tpu_inference/core/disagg_utils.py +49 -0
  32. tpu_inference/core/sched/__init__.py +0 -0
  33. tpu_inference/core/sched/dp_scheduler.py +523 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/jax_parallel_state.py +67 -0
  36. tpu_inference/distributed/tpu_connector.py +727 -0
  37. tpu_inference/distributed/utils.py +60 -0
  38. tpu_inference/env_override.py +9 -0
  39. tpu_inference/envs.py +160 -0
  40. tpu_inference/executors/__init__.py +0 -0
  41. tpu_inference/executors/ray_distributed_executor.py +382 -0
  42. tpu_inference/experimental/__init__.py +0 -0
  43. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  44. tpu_inference/kernels/__init__.py +0 -0
  45. tpu_inference/kernels/collectives/__init__.py +0 -0
  46. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  47. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  48. tpu_inference/kernels/collectives/util.py +47 -0
  49. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  50. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  51. tpu_inference/kernels/fused_moe/__init__.py +0 -0
  52. tpu_inference/kernels/fused_moe/v1/__init__.py +0 -0
  53. tpu_inference/kernels/fused_moe/v1/kernel.py +1566 -0
  54. tpu_inference/kernels/mla/__init__.py +0 -0
  55. tpu_inference/kernels/mla/v1/__init__.py +0 -0
  56. tpu_inference/kernels/mla/v1/kernel.py +1349 -0
  57. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  58. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  59. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  60. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  61. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  62. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  66. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1501 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel_hd64.py +1603 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +4147 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes_hd64.py +367 -0
  71. tpu_inference/kernels/ragged_paged_attention/v3/util.py +51 -0
  72. tpu_inference/layers/__init__.py +0 -0
  73. tpu_inference/layers/common/__init__.py +0 -0
  74. tpu_inference/layers/common/attention_interface.py +396 -0
  75. tpu_inference/layers/common/attention_metadata.py +34 -0
  76. tpu_inference/layers/common/binary_search.py +295 -0
  77. tpu_inference/layers/common/quant_methods.py +8 -0
  78. tpu_inference/layers/common/sharding.py +582 -0
  79. tpu_inference/layers/jax/__init__.py +0 -0
  80. tpu_inference/layers/jax/attention/__init__.py +0 -0
  81. tpu_inference/layers/jax/attention/attention.py +255 -0
  82. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  83. tpu_inference/layers/jax/attention/gpt_oss_attention.py +262 -0
  84. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  85. tpu_inference/layers/jax/base.py +151 -0
  86. tpu_inference/layers/jax/constants.py +88 -0
  87. tpu_inference/layers/jax/layers.py +301 -0
  88. tpu_inference/layers/jax/misc.py +16 -0
  89. tpu_inference/layers/jax/moe/__init__.py +0 -0
  90. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  91. tpu_inference/layers/jax/moe/gpt_oss_moe.py +185 -0
  92. tpu_inference/layers/jax/moe/moe.py +209 -0
  93. tpu_inference/layers/jax/rope.py +280 -0
  94. tpu_inference/layers/jax/rope_interface.py +214 -0
  95. tpu_inference/layers/jax/sample/__init__.py +0 -0
  96. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  97. tpu_inference/layers/jax/sample/sampling.py +96 -0
  98. tpu_inference/layers/jax/sample/sampling_metadata.py +76 -0
  99. tpu_inference/layers/jax/transformer_block.py +107 -0
  100. tpu_inference/layers/vllm/__init__.py +0 -0
  101. tpu_inference/layers/vllm/attention.py +221 -0
  102. tpu_inference/layers/vllm/fused_moe.py +469 -0
  103. tpu_inference/layers/vllm/linear_common.py +186 -0
  104. tpu_inference/layers/vllm/quantization/__init__.py +39 -0
  105. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  106. tpu_inference/layers/vllm/quantization/common.py +110 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  108. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +120 -0
  109. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors_moe.py +203 -0
  110. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  111. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  112. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  113. tpu_inference/layers/vllm/quantization/mxfp4.py +331 -0
  114. tpu_inference/layers/vllm/quantization/unquantized.py +368 -0
  115. tpu_inference/layers/vllm/sharding.py +230 -0
  116. tpu_inference/logger.py +10 -0
  117. tpu_inference/lora/__init__.py +0 -0
  118. tpu_inference/lora/torch_lora_ops.py +103 -0
  119. tpu_inference/lora/torch_punica_tpu.py +310 -0
  120. tpu_inference/models/__init__.py +0 -0
  121. tpu_inference/models/common/__init__.py +0 -0
  122. tpu_inference/models/common/model_loader.py +478 -0
  123. tpu_inference/models/jax/__init__.py +0 -0
  124. tpu_inference/models/jax/deepseek_v3.py +868 -0
  125. tpu_inference/models/jax/gpt_oss.py +492 -0
  126. tpu_inference/models/jax/jax_intermediate_tensor.py +79 -0
  127. tpu_inference/models/jax/llama3.py +376 -0
  128. tpu_inference/models/jax/llama4.py +629 -0
  129. tpu_inference/models/jax/llama_eagle3.py +336 -0
  130. tpu_inference/models/jax/llama_guard_4.py +361 -0
  131. tpu_inference/models/jax/qwen2.py +376 -0
  132. tpu_inference/models/jax/qwen2_5_vl.py +1218 -0
  133. tpu_inference/models/jax/qwen3.py +303 -0
  134. tpu_inference/models/jax/utils/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/file_utils.py +96 -0
  136. tpu_inference/models/jax/utils/multi_modal_utils.py +163 -0
  137. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  138. tpu_inference/models/jax/utils/quantization/configs/fp8_all_modules_w_only.yaml +5 -0
  139. tpu_inference/models/jax/utils/quantization/configs/fp8_default.yaml +6 -0
  140. tpu_inference/models/jax/utils/quantization/configs/int8_all_modules_w_only.yaml +5 -0
  141. tpu_inference/models/jax/utils/quantization/configs/int8_default.yaml +6 -0
  142. tpu_inference/models/jax/utils/quantization/mxfp4_utils.py +105 -0
  143. tpu_inference/models/jax/utils/quantization/quantization_utils.py +650 -0
  144. tpu_inference/models/jax/utils/weight_utils.py +584 -0
  145. tpu_inference/models/vllm/__init__.py +0 -0
  146. tpu_inference/models/vllm/vllm_model_wrapper.py +293 -0
  147. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  148. tpu_inference/platforms/__init__.py +2 -0
  149. tpu_inference/platforms/tpu_platform.py +275 -0
  150. tpu_inference/runner/__init__.py +0 -0
  151. tpu_inference/runner/block_table.py +122 -0
  152. tpu_inference/runner/compilation_manager.py +865 -0
  153. tpu_inference/runner/input_batch.py +435 -0
  154. tpu_inference/runner/kv_cache.py +132 -0
  155. tpu_inference/runner/kv_cache_manager.py +478 -0
  156. tpu_inference/runner/lora_utils.py +92 -0
  157. tpu_inference/runner/multimodal_manager.py +217 -0
  158. tpu_inference/runner/persistent_batch_manager.py +282 -0
  159. tpu_inference/runner/speculative_decoding_manager.py +248 -0
  160. tpu_inference/runner/structured_decoding_manager.py +87 -0
  161. tpu_inference/runner/tpu_runner.py +1744 -0
  162. tpu_inference/runner/utils.py +426 -0
  163. tpu_inference/spec_decode/__init__.py +0 -0
  164. tpu_inference/spec_decode/jax/__init__.py +0 -0
  165. tpu_inference/spec_decode/jax/eagle3.py +417 -0
  166. tpu_inference/tpu_info.py +78 -0
  167. tpu_inference/utils.py +340 -0
  168. tpu_inference/worker/__init__.py +0 -0
  169. tpu_inference/worker/tpu_worker.py +458 -0
  170. tpu_inference-0.0.1rc1.dist-info/METADATA +108 -0
  171. tpu_inference-0.0.1rc1.dist-info/RECORD +174 -0
  172. tpu_inference-0.0.1rc1.dist-info/WHEEL +5 -0
  173. tpu_inference-0.0.1rc1.dist-info/licenses/LICENSE +201 -0
  174. tpu_inference-0.0.1rc1.dist-info/top_level.txt +2 -0
tpu_inference/runner/multimodal_manager.py
@@ -0,0 +1,217 @@
+ from typing import TYPE_CHECKING, Optional
+
+ import jax
+ import jax.numpy as jnp
+ from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
+ from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
+ from vllm.multimodal.utils import group_mm_kwargs_by_modality
+ from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
+ from vllm.v1.worker.utils import (gather_mm_placeholders,
+                                   scatter_mm_placeholders)
+
+ from tpu_inference.models.jax.utils.multi_modal_utils import (
+     flatten_embeddings, sanity_check_mm_encoder_outputs)
+
+ if TYPE_CHECKING:
+     from tpu_inference.runner.tpu_runner import TPUModelRunner
+
+
+ class MultiModalManager:
+
+     def __init__(self, runner: "TPUModelRunner"):
+         self.runner = runner
+
+     def calc_mrope_positions(self, scheduler_output: "VllmSchedulerOutput"):
+         mrope_pos_ptr = 0
+         for index, req_id in enumerate(self.runner.input_batch.req_ids):
+             req = self.runner.requests[req_id]
+             assert req.mrope_positions is not None
+
+             num_computed_tokens = \
+                 self.runner.input_batch.num_computed_tokens_cpu[index]
+             num_scheduled_tokens = \
+                 scheduler_output.num_scheduled_tokens[req_id]
+             num_prompt_tokens = len(req.prompt_token_ids)
+
+             if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
+                 prompt_part_len = max(
+                     0, num_prompt_tokens - num_computed_tokens)
+                 completion_part_len = max(
+                     0, num_scheduled_tokens - prompt_part_len)
+             else:
+                 prompt_part_len = num_scheduled_tokens
+                 completion_part_len = 0
+
+             assert num_scheduled_tokens == prompt_part_len + completion_part_len
+
+             if prompt_part_len > 0:
+                 # prompt's mrope_positions are pre-computed
+                 dst_start = mrope_pos_ptr
+                 dst_end = mrope_pos_ptr + prompt_part_len
+                 src_start = num_computed_tokens
+                 src_end = num_computed_tokens + prompt_part_len
+
+                 self.runner.mrope_positions_cpu[:, dst_start:dst_end] = \
+                     req.mrope_positions[:, src_start:src_end]
+
+                 mrope_pos_ptr += prompt_part_len
+
+             if completion_part_len > 0:
+                 # compute completion's mrope_positions on-the-fly
+                 dst_start = mrope_pos_ptr
+                 dst_end = mrope_pos_ptr + completion_part_len
+
+                 MRotaryEmbedding.get_next_input_positions_tensor(
+                     out=self.runner.mrope_positions_cpu,
+                     out_offset=dst_start,
+                     mrope_position_delta=req.mrope_position_delta,
+                     context_len=num_computed_tokens + prompt_part_len,
+                     num_new_tokens=completion_part_len,
+                 )
+
+                 mrope_pos_ptr += completion_part_len
+
+     def execute_mm_encoder(self, scheduler_output: "VllmSchedulerOutput"):
+         import torch
+         scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
+         if not scheduled_encoder_inputs:
+             return
+
+         # Batch the multi-modal inputs.
+         mm_kwargs = list[MultiModalKwargsItem]()
+         # List of tuples (mm_hash, pos_info).
+         mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
+         for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
+             req_state = self.runner.requests[req_id]
+             for mm_input_id in encoder_input_ids:
+                 mm_feature = req_state.mm_features[mm_input_id]
+                 mm_hash = mm_feature.identifier
+                 mm_kwargs.append(mm_feature.data)
+                 mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
+
+         # Batch mm inputs as much as we can: if a request in the batch has
+         # multiple modalities or a different modality than the previous one,
+         # we process it separately to preserve item order.
+         # FIXME(ywang96): This is a hacky way to deal with multiple modalities
+         # in the same batch while still being able to benefit from batching
+         # multimodal inputs. The proper solution should be reordering the
+         # encoder outputs.
+         encoder_outputs = []
+         for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
+                 mm_kwargs, merge_by_field_config=False):
+             batched_mm_inputs = mm_kwargs_group
+             # Convert torch tensors to numpy arrays that JAX can handle.
+             if "pixel_values" in batched_mm_inputs and isinstance(
+                     batched_mm_inputs["pixel_values"], list):
+                 batched_mm_inputs["pixel_values"] = torch.cat(
+                     batched_mm_inputs["pixel_values"], dim=0)
+
+             image_grid_thw = ()
+             for key, value in batched_mm_inputs.items():
+                 if isinstance(value, torch.Tensor):
+                     if key == 'image_grid_thw':
+                         # Change it to a tuple of tuples to make it hashable
+                         # for JIT.
+                         # Shape: (B, N, 3) -> (B*N, 3) -> tuple of tuples
+                         grid_thw_tensor = batched_mm_inputs[key]
+                         grid_thw_reshaped = grid_thw_tensor.reshape(-1, 3)
+                         image_grid_thw = tuple(
+                             tuple(row) for row in grid_thw_reshaped.tolist())
+                         continue
+
+                     if value.dtype == torch.bfloat16:
+                         batched_mm_inputs[key] = value.to(
+                             torch.float32).numpy().astype(jnp.bfloat16)
+                     else:
+                         batched_mm_inputs[key] = value.numpy()
+             batched_mm_inputs.pop('image_grid_thw')
+
+             # Run the encoder.
+             # `curr_group_outputs` is either of the following:
+             # 1. A tensor of shape (num_items, feature_size, hidden_size)
+             #    in case feature_size is fixed across all multimodal items.
+             # 2. A list or tuple (length: num_items) of tensors, each of shape
+             #    (feature_size, hidden_size) in case the feature size is
+             #    dynamic depending on the input multimodal items.
+             curr_group_outputs = self.runner.get_multimodal_embeddings_fn(
+                 self.runner.state, image_grid_thw, **batched_mm_inputs)
+
+             sanity_check_mm_encoder_outputs(
+                 curr_group_outputs,
+                 expected_num_items=num_items,
+             )
+
+             for output in curr_group_outputs:
+                 encoder_outputs.append(output)
+
+         # Cache the encoder outputs, keyed by mm_hash.
+         for (mm_hash, pos_info), output in zip(
+                 mm_hashes_pos,
+                 encoder_outputs,
+         ):
+             self.runner.encoder_cache[mm_hash] = scatter_mm_placeholders(
+                 output,
+                 is_embed=pos_info.is_embed,
+             )
+
+     def gather_mm_embeddings(self, scheduler_output: "VllmSchedulerOutput",
+                              target_pad_len: int) -> Optional[jax.Array]:
+         mm_embeds: list[jax.Array] = []
+         for req_id in self.runner.input_batch.req_ids:
+             num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
+                 req_id]
+             req_state = self.runner.requests[req_id]
+             num_computed_tokens = req_state.num_computed_tokens
+             mm_features = req_state.mm_features
+             for mm_feature in mm_features:
+                 pos_info = mm_feature.mm_position
+                 start_pos = pos_info.offset
+                 num_encoder_tokens = pos_info.length
+
+                 # The encoder output is needed if the two ranges overlap:
+                 # [num_computed_tokens,
+                 #  num_computed_tokens + num_scheduled_tokens) and
+                 # [start_pos, start_pos + num_encoder_tokens)
+                 if start_pos >= num_computed_tokens + num_scheduled_tokens:
+                     # The encoder output is not needed in this step.
+                     break
+                 if start_pos + num_encoder_tokens <= num_computed_tokens:
+                     # The encoder output is already processed and stored
+                     # in the decoder's KV cache.
+                     continue
+
+                 start_idx = max(num_computed_tokens - start_pos, 0)
+                 end_idx = min(
+                     num_computed_tokens - start_pos + num_scheduled_tokens,
+                     num_encoder_tokens)
+                 assert start_idx < end_idx
+                 mm_hash = mm_feature.identifier
+                 encoder_output = self.runner.encoder_cache.get(mm_hash, None)
+                 assert encoder_output is not None, \
+                     f"Encoder cache miss for {mm_hash}."
+
+                 if (is_embed := pos_info.is_embed) is not None:
+                     is_embed = is_embed[start_idx:end_idx]
+
+                 mm_embeds_item = gather_mm_placeholders(
+                     encoder_output[start_idx:end_idx],
+                     is_embed=is_embed,
+                 )
+                 mm_embeds.append(mm_embeds_item)
+         if not mm_embeds:
+             return None
+         flattened_embeds = flatten_embeddings(mm_embeds)
+         if flattened_embeds.shape[0] == 0:
+             return None
+
+         # Pad the flattened embeddings to the target length so downstream
+         # JAX computations see a fixed shape.
+         padding = jnp.zeros((target_pad_len - flattened_embeds.shape[0],
+                              flattened_embeds.shape[1]),
+                             dtype=flattened_embeds.dtype)
+         flattened_embeds = jnp.concatenate([flattened_embeds, padding],
+                                            axis=0)
+
+         return flattened_embeds
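
For orientation, the slicing rule in gather_mm_embeddings above can be illustrated with a small standalone sketch. This is not code from the wheel; the function name and the example numbers are hypothetical and only restate the overlap arithmetic between the scheduled-token window and a multimodal placeholder range.

# Standalone sketch (not part of the wheel) of the interval-overlap rule used
# by gather_mm_embeddings. Names and values are illustrative only.

def needed_encoder_slice(num_computed_tokens: int, num_scheduled_tokens: int,
                         start_pos: int, num_encoder_tokens: int):
    """Return the (start_idx, end_idx) slice of an encoder output consumed by
    the current step, or None if the output is not needed."""
    step_end = num_computed_tokens + num_scheduled_tokens
    if start_pos >= step_end:
        return None  # placeholder lies entirely after this step
    if start_pos + num_encoder_tokens <= num_computed_tokens:
        return None  # placeholder already written to the KV cache
    start_idx = max(num_computed_tokens - start_pos, 0)
    end_idx = min(num_computed_tokens - start_pos + num_scheduled_tokens,
                  num_encoder_tokens)
    return start_idx, end_idx


if __name__ == "__main__":
    # A placeholder spans positions [10, 30); 15 prompt tokens are already
    # computed and 8 more are scheduled, so step positions 15..23 overlap the
    # placeholder and map to encoder rows 5..13.
    print(needed_encoder_slice(15, 8, 10, 20))  # -> (5, 13)
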
tpu_inference/runner/persistent_batch_manager.py
@@ -0,0 +1,282 @@
+ from typing import Dict
+
+ import jax
+ from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
+
+ from tpu_inference.logger import init_logger
+ from tpu_inference.runner.input_batch import CachedRequestState, InputBatch
+
+ logger = init_logger(__name__)
+
+
+ class PersistentBatchManager:
+
+     def __init__(self, requests: Dict[str, CachedRequestState],
+                  input_batch: InputBatch,
+                  encoder_cache: Dict[str, 'jax.Array'],
+                  uses_mrope: bool, model_config, is_last_rank: bool):
+         self.requests = requests
+         self.input_batch = input_batch
+         self.encoder_cache = encoder_cache
+         self.uses_mrope = uses_mrope
+         self.model_config = model_config
+         self.is_last_rank = is_last_rank
+
+     def _reorder_batch(self, scheduler_output: "VllmSchedulerOutput") -> int:
+         """Reorder the scheduled requests into an RPA-kernel-friendly
+         distribution (decode_only, fixed_chunked_prefill_only, mixed) and set
+         the request distribution accordingly.
+
+         Returns:
+             The number of request swaps performed.
+         """
+         # Note(jevinjiang): currently we only consider decode_only.
+         num_reqs = self.input_batch.num_reqs
+         swap_cnt = 0
+         if num_reqs <= 0:
+             return swap_cnt
+         # Use a two-pointer approach to move the decode requests to the front.
+         i, j = 0, num_reqs - 1
+         while i < j:
+             i_req_id = self.input_batch.req_ids[i]
+             j_req_id = self.input_batch.req_ids[j]
+
+             if scheduler_output.num_scheduled_tokens[i_req_id] == 1:
+                 # i is a decode request, move to the next one.
+                 i += 1
+             elif scheduler_output.num_scheduled_tokens[j_req_id] > 1:
+                 # j is a prefill request, move to the previous one.
+                 j -= 1
+             else:
+                 # Swap i and j.
+                 self.input_batch.swap_states(i, j)
+                 i += 1
+                 j -= 1
+                 swap_cnt += 1
+
+         num_decode = i + int(scheduler_output.num_scheduled_tokens[
+             self.input_batch.req_ids[i]] == 1)
+
+         self.input_batch.request_distribution = [
+             num_decode, num_decode, num_reqs
+         ]
+
+         return swap_cnt
+
+     def update_states(self, scheduler_output: "VllmSchedulerOutput",
+                       get_mrope_input_positions_fn) -> bool:
+         """Update the cached states and the persistent batch with the
+         scheduler output.
+
+         The updated states are used by the `_prepare_inputs` function to
+         create the input TPU tensors for the model.
+
+         Returns:
+             True if there is a new/resumed/paused/finished request.
+             If False, we can skip copying SamplingMetadata to the TPU.
+         """
+         # Remove finished requests from the cached states.
+         for req_id in scheduler_output.finished_req_ids:
+             self.requests.pop(req_id, None)
+
+         # Remove the finished requests from the persistent batch.
+         # NOTE(woosuk): There could be an edge case where finished_req_ids and
+         # scheduled_req_ids overlap. This happens when a request is aborted
+         # and then resubmitted with the same ID. In this case, we treat them
+         # as two distinct requests - clearing the cached states for the first
+         # request and handling the second as a new request.
+         removed_req_indices: list[int] = []
+         for req_id in scheduler_output.finished_req_ids:
+             req_index = self.input_batch.remove_request(req_id)
+             if req_index is not None:
+                 removed_req_indices.append(req_index)
+
+         # Free the cached encoder outputs.
+         for mm_hash in scheduler_output.free_encoder_mm_hashes:
+             self.encoder_cache.pop(mm_hash, None)
+
+         # Remove the unscheduled requests from the persistent batch.
+         # NOTE(woosuk): The unscheduled requests are either preempted requests
+         # or running requests that are not scheduled in this step. We remove
+         # them from the persistent batch but keep their cached states since
+         # they will be scheduled again sometime in the future.
+         scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
+         cached_req_ids = self.input_batch.req_id_to_index.keys()
+         unscheduled_req_ids = cached_req_ids - scheduled_req_ids
+         # NOTE(woosuk): The persistent batch optimization assumes that
+         # consecutive batches contain mostly the same requests. If batches
+         # have low request overlap (e.g., alternating between two distinct
+         # sets of requests), this optimization becomes very inefficient.
+         for req_id in unscheduled_req_ids:
+             req_index = self.input_batch.remove_request(req_id)
+             assert req_index is not None
+             removed_req_indices.append(req_index)
+
+         req_ids_to_add: list[str] = []
+         # Add new requests to the cached states.
+         for new_req_data in scheduler_output.scheduled_new_reqs:
+             req_id = new_req_data.req_id
+             sampling_params = new_req_data.sampling_params
+
+             self.requests[req_id] = CachedRequestState(
+                 req_id=req_id,
+                 prompt_token_ids=new_req_data.prompt_token_ids,
+                 mm_features=new_req_data.mm_features,
+                 sampling_params=sampling_params,
+                 pooling_params=None,
+                 generator=None,
+                 block_ids=new_req_data.block_ids,
+                 num_computed_tokens=new_req_data.num_computed_tokens,
+                 output_token_ids=[],
+                 lora_request=new_req_data.lora_request,
+             )
+
+             req_ids_to_add.append(req_id)
+
+             # Only relevant for models using M-RoPE (e.g., Qwen2-VL).
+             if self.uses_mrope:
+                 image_grid_thw = []
+                 video_grid_thw = []
+                 second_per_grid_ts = []
+                 audio_feature_lengths = []
+                 use_audio_in_video = False
+                 for mm_feature in self.requests[req_id].mm_features:
+                     item = mm_feature.data
+                     if item is None:
+                         continue
+                     mm_input = item.get_data()
+                     if mm_input.get("image_grid_thw") is not None:
+                         image_grid_thw.append(
+                             mm_input["image_grid_thw"].tolist())
+                     if mm_input.get("video_grid_thw") is not None:
+                         video_grid_thw.append(
+                             mm_input["video_grid_thw"].tolist())
+                     if mm_input.get("second_per_grid_ts") is not None:
+                         second_per_grid_ts.append(
+                             mm_input["second_per_grid_ts"])
+                     if mm_input.get("audio_feature_lengths") is not None:
+                         audio_feature_lengths.append(
+                             mm_input["audio_feature_lengths"])
+                     if mm_input.get("use_audio_in_video") is True:
+                         use_audio_in_video = True
+
+                 hf_config = self.model_config.hf_config
+
+                 self.requests[req_id].mrope_positions, self.requests[
+                     req_id].mrope_position_delta = get_mrope_input_positions_fn(
+                         self.requests[req_id].prompt_token_ids,
+                         hf_config=hf_config,
+                         image_grid_thw=image_grid_thw,
+                         video_grid_thw=video_grid_thw,
+                         second_per_grid_ts=second_per_grid_ts,
+                         audio_feature_lengths=audio_feature_lengths,
+                         use_audio_in_video=use_audio_in_video,
+                     )
+
+         # Update the states of the running/resumed requests.
+         req_data = scheduler_output.scheduled_cached_reqs
+         for i, req_id in enumerate(req_data.req_ids):
+             req_state = self.requests[req_id]
+             num_computed_tokens = req_data.num_computed_tokens[i]
+             new_block_ids = req_data.new_block_ids[i]
+             resumed_from_preemption = req_data.resumed_from_preemption[i]
+             num_output_tokens = req_data.num_output_tokens[i]
+
+             # Update the cached states.
+             req_state.num_computed_tokens = num_computed_tokens
+             req_index = self.input_batch.req_id_to_index.get(req_id)
+
+             if not self.is_last_rank:
+                 # When using PP, the scheduler sends the sampled tokens back,
+                 # because there's no direct communication between the first-
+                 # stage worker and the last-stage worker.
+                 new_token_ids = req_data.new_token_ids[i]
+                 # Add the sampled token(s) from the previous step (if any).
+                 # This doesn't include "unverified" tokens like spec tokens.
+                 num_new_tokens = (num_computed_tokens + len(new_token_ids) -
+                                   req_state.num_tokens)
+                 if num_new_tokens == 1:
+                     req_state.output_token_ids.append(new_token_ids[-1])
+                 elif num_new_tokens > 0:
+                     req_state.output_token_ids.extend(
+                         new_token_ids[-num_new_tokens:])
+             elif num_output_tokens < len(req_state.output_token_ids):
+                 del req_state.output_token_ids[num_output_tokens:]
+                 if req_index is not None:
+                     end_idx = (self.input_batch.num_prompt_tokens[req_index] +
+                                num_output_tokens)
+                     self.input_batch.num_tokens[req_index] = end_idx
+                     self.input_batch.num_tokens_no_spec[req_index] = end_idx
+
+             # Update the block IDs.
+             if not resumed_from_preemption:
+                 if new_block_ids is not None:
+                     # Append the new blocks to the existing block IDs.
+                     for block_ids, new_ids in zip(req_state.block_ids,
+                                                   new_block_ids):
+                         block_ids.extend(new_ids)
+             else:
+                 assert new_block_ids is not None
+                 # The request is resumed from preemption.
+                 # Replace the existing block IDs with the new ones.
+                 req_state.block_ids = new_block_ids
+
+             if req_index is None:
+                 # The request is not in the persistent batch.
+                 # The request was either preempted and resumed later, or was
+                 # not scheduled in the previous step and needs to be added
+                 # again.
+                 req_ids_to_add.append(req_id)
+                 continue
+
+             # Update the persistent batch.
+             self.input_batch.num_computed_tokens_cpu[
+                 req_index] = num_computed_tokens
+             if new_block_ids is not None:
+                 self.input_batch.block_table.append_row(
+                     new_block_ids, req_index)
+
+             # For the last rank, we don't need to update the token_ids_cpu
+             # because the sampled tokens are already cached.
+             if not self.is_last_rank:
+                 start_token_index = num_computed_tokens
+                 end_token_index = num_computed_tokens + len(new_token_ids)
+                 self.input_batch.token_ids_cpu[
+                     req_index,
+                     start_token_index:end_token_index] = new_token_ids
+                 self.input_batch.num_tokens_no_spec[
+                     req_index] = end_token_index
+                 self.input_batch.num_tokens[req_index] = end_token_index
+
+             # Add spec_token_ids to token_ids_cpu.
+             spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
+                 req_id, ())
+             if spec_token_ids:
+                 num_spec_tokens = len(spec_token_ids)
+                 start_index = self.input_batch.num_tokens_no_spec[req_index]
+                 end_token_index = start_index + num_spec_tokens
+                 self.input_batch.token_ids_cpu[
+                     req_index, start_index:end_token_index] = spec_token_ids
+                 # NOTE(woosuk): `num_tokens` here may include spec tokens.
+                 self.input_batch.num_tokens[req_index] += num_spec_tokens
+
+         # Add the new or resumed requests to the persistent batch.
+         # The smaller empty indices are filled first.
+         removed_req_indices = sorted(removed_req_indices, reverse=True)
+         for req_id in req_ids_to_add:
+             req_state = self.requests[req_id]
+             if removed_req_indices:
+                 # Fill the empty index.
+                 req_index = removed_req_indices.pop()
+             else:
+                 # Append to the end.
+                 req_index = None
+             self.input_batch.add_request(req_state, req_index)
+
+         # Condense the batched states if there are empty indices.
+         if removed_req_indices:
+             self.input_batch.condense(removed_req_indices)
+
+         batch_changed = len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
+         # TODO(jevinjiang): I assume we do not need to set batch_changed to
+         # true if just swapping requests.
+         self._reorder_batch(scheduler_output)
+         return batch_changed
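
For orientation, the decode-first reordering in PersistentBatchManager._reorder_batch above can be sketched on a plain Python list. This is not code from the wheel; the function and the in-place list swap are hypothetical stand-ins for InputBatch.swap_states, under the same rule that a request with exactly one scheduled token is a decode.

# Standalone sketch (not part of the wheel) of the two-pointer, decode-first
# reordering used by _reorder_batch. A plain list stands in for InputBatch.

def reorder_decode_first(req_ids: list[str],
                         num_scheduled_tokens: dict[str, int]) -> int:
    """Move requests with exactly one scheduled token (decodes) to the front.

    Returns the number of swaps performed, mirroring the method above.
    """
    swap_cnt = 0
    i, j = 0, len(req_ids) - 1
    while i < j:
        if num_scheduled_tokens[req_ids[i]] == 1:
            i += 1  # a decode request is already at the front
        elif num_scheduled_tokens[req_ids[j]] > 1:
            j -= 1  # a prefill request is already at the back
        else:
            # i is a prefill and j is a decode: swap them.
            req_ids[i], req_ids[j] = req_ids[j], req_ids[i]
            i += 1
            j -= 1
            swap_cnt += 1
    return swap_cnt


if __name__ == "__main__":
    reqs = ["a", "b", "c", "d"]
    tokens = {"a": 17, "b": 1, "c": 5, "d": 1}  # b and d are decodes
    swaps = reorder_decode_first(reqs, tokens)
    print(reqs, swaps)  # -> ['d', 'b', 'c', 'a'] 1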