tpu_inference-0.11.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tpu-inference might be problematic.

Files changed (168)
  1. tests/__init__.py +0 -0
  2. tests/core/__init__.py +0 -0
  3. tests/core/test_adapters.py +83 -0
  4. tests/core/test_core_tpu.py +523 -0
  5. tests/core/test_disagg_executor.py +60 -0
  6. tests/core/test_disagg_utils.py +53 -0
  7. tests/core/test_init.py +49 -0
  8. tests/kernels/__init__.py +0 -0
  9. tests/kernels/quantized_matmul_kernel_test.py +191 -0
  10. tests/kernels/ragged_kv_cache_update_v2_test.py +234 -0
  11. tests/kernels/ragged_paged_attention_kernel_v2_test.py +400 -0
  12. tests/kernels/ragged_paged_attention_kernel_v3_test.py +504 -0
  13. tests/lora/__init__.py +0 -0
  14. tests/lora/test_lora.py +123 -0
  15. tests/test_base.py +201 -0
  16. tests/test_quantization.py +836 -0
  17. tests/test_tpu_info.py +120 -0
  18. tests/test_utils.py +218 -0
  19. tests/tpu_backend_test.py +59 -0
  20. tpu_inference/__init__.py +30 -0
  21. tpu_inference/adapters/__init__.py +0 -0
  22. tpu_inference/adapters/vllm_adapters.py +42 -0
  23. tpu_inference/adapters/vllm_config_adapters.py +134 -0
  24. tpu_inference/backend.py +69 -0
  25. tpu_inference/core/__init__.py +0 -0
  26. tpu_inference/core/adapters.py +153 -0
  27. tpu_inference/core/core_tpu.py +776 -0
  28. tpu_inference/core/disagg_executor.py +117 -0
  29. tpu_inference/core/disagg_utils.py +51 -0
  30. tpu_inference/di/__init__.py +0 -0
  31. tpu_inference/di/abstracts.py +28 -0
  32. tpu_inference/di/host.py +76 -0
  33. tpu_inference/di/interfaces.py +51 -0
  34. tpu_inference/distributed/__init__.py +0 -0
  35. tpu_inference/distributed/tpu_connector.py +699 -0
  36. tpu_inference/distributed/utils.py +59 -0
  37. tpu_inference/executors/__init__.py +0 -0
  38. tpu_inference/executors/ray_distributed_executor.py +346 -0
  39. tpu_inference/experimental/__init__.py +0 -0
  40. tpu_inference/experimental/llama3_jax_stashed.py +258 -0
  41. tpu_inference/interfaces/__init__.py +0 -0
  42. tpu_inference/interfaces/cache.py +31 -0
  43. tpu_inference/interfaces/config.py +47 -0
  44. tpu_inference/interfaces/config_parts.py +117 -0
  45. tpu_inference/interfaces/engine.py +51 -0
  46. tpu_inference/interfaces/outputs.py +22 -0
  47. tpu_inference/interfaces/params.py +21 -0
  48. tpu_inference/interfaces/platform.py +74 -0
  49. tpu_inference/interfaces/request.py +39 -0
  50. tpu_inference/interfaces/scheduler.py +31 -0
  51. tpu_inference/kernels/__init__.py +0 -0
  52. tpu_inference/kernels/collectives/__init__.py +0 -0
  53. tpu_inference/kernels/collectives/all_gather_matmul.py +735 -0
  54. tpu_inference/kernels/collectives/all_gather_matmul_tuned_block_sizes.py +60 -0
  55. tpu_inference/kernels/collectives/util.py +47 -0
  56. tpu_inference/kernels/flash_attention/__init__.py +0 -0
  57. tpu_inference/kernels/flash_attention/kernel.py +772 -0
  58. tpu_inference/kernels/quantized_matmul/__init__.py +0 -0
  59. tpu_inference/kernels/quantized_matmul/kernel.py +395 -0
  60. tpu_inference/kernels/quantized_matmul/tuned_block_sizes.py +609 -0
  61. tpu_inference/kernels/quantized_matmul/util.py +58 -0
  62. tpu_inference/kernels/ragged_paged_attention/__init__.py +0 -0
  63. tpu_inference/kernels/ragged_paged_attention/v2/__init__.py +0 -0
  64. tpu_inference/kernels/ragged_paged_attention/v2/kernel.py +875 -0
  65. tpu_inference/kernels/ragged_paged_attention/v2/ragged_kv_cache_update.py +287 -0
  66. tpu_inference/kernels/ragged_paged_attention/v2/tuned_block_sizes.py +1482 -0
  67. tpu_inference/kernels/ragged_paged_attention/v3/__init__.py +0 -0
  68. tpu_inference/kernels/ragged_paged_attention/v3/kernel.py +1447 -0
  69. tpu_inference/kernels/ragged_paged_attention/v3/tuned_block_sizes.py +3834 -0
  70. tpu_inference/kernels/ragged_paged_attention/v3/util.py +47 -0
  71. tpu_inference/layers/__init__.py +0 -0
  72. tpu_inference/layers/common/__init__.py +0 -0
  73. tpu_inference/layers/common/attention_metadata.py +34 -0
  74. tpu_inference/layers/jax/__init__.py +0 -0
  75. tpu_inference/layers/jax/attention/__init__.py +0 -0
  76. tpu_inference/layers/jax/attention/attention.py +254 -0
  77. tpu_inference/layers/jax/attention/deepseek_v3_attention.py +354 -0
  78. tpu_inference/layers/jax/attention/llama4_attention.py +153 -0
  79. tpu_inference/layers/jax/attention_interface.py +356 -0
  80. tpu_inference/layers/jax/base.py +151 -0
  81. tpu_inference/layers/jax/binary_search.py +295 -0
  82. tpu_inference/layers/jax/constants.py +88 -0
  83. tpu_inference/layers/jax/layers.py +301 -0
  84. tpu_inference/layers/jax/misc.py +16 -0
  85. tpu_inference/layers/jax/moe/__init__.py +0 -0
  86. tpu_inference/layers/jax/moe/deepseek_v3_moe.py +608 -0
  87. tpu_inference/layers/jax/moe/moe.py +209 -0
  88. tpu_inference/layers/jax/rope.py +172 -0
  89. tpu_inference/layers/jax/rope_interface.py +214 -0
  90. tpu_inference/layers/jax/sample/__init__.py +0 -0
  91. tpu_inference/layers/jax/sample/rejection_sampler.py +515 -0
  92. tpu_inference/layers/jax/sample/sampling.py +95 -0
  93. tpu_inference/layers/jax/sample/sampling_metadata.py +69 -0
  94. tpu_inference/layers/jax/sharding.py +406 -0
  95. tpu_inference/layers/jax/transformer_block.py +76 -0
  96. tpu_inference/layers/vllm/__init__.py +0 -0
  97. tpu_inference/layers/vllm/attention.py +184 -0
  98. tpu_inference/layers/vllm/fused_moe.py +399 -0
  99. tpu_inference/layers/vllm/linear_common.py +186 -0
  100. tpu_inference/layers/vllm/quantization/__init__.py +34 -0
  101. tpu_inference/layers/vllm/quantization/awq.py +207 -0
  102. tpu_inference/layers/vllm/quantization/common.py +105 -0
  103. tpu_inference/layers/vllm/quantization/compressed_tensors/__init__.py +0 -0
  104. tpu_inference/layers/vllm/quantization/compressed_tensors/compressed_tensors.py +121 -0
  105. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/__init__.py +0 -0
  106. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +208 -0
  107. tpu_inference/layers/vllm/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +136 -0
  108. tpu_inference/layers/vllm/quantization/unquantized.py +263 -0
  109. tpu_inference/layers/vllm/sharding.py +151 -0
  110. tpu_inference/logger.py +10 -0
  111. tpu_inference/lora/__init__.py +0 -0
  112. tpu_inference/lora/torch_lora_ops.py +103 -0
  113. tpu_inference/lora/torch_punica_tpu.py +308 -0
  114. tpu_inference/mock/__init__.py +0 -0
  115. tpu_inference/mock/vllm_config_utils.py +28 -0
  116. tpu_inference/mock/vllm_envs.py +1233 -0
  117. tpu_inference/mock/vllm_logger.py +212 -0
  118. tpu_inference/mock/vllm_logging_utils.py +15 -0
  119. tpu_inference/models/__init__.py +0 -0
  120. tpu_inference/models/common/__init__.py +0 -0
  121. tpu_inference/models/common/model_loader.py +433 -0
  122. tpu_inference/models/jax/__init__.py +0 -0
  123. tpu_inference/models/jax/deepseek_v3.py +868 -0
  124. tpu_inference/models/jax/llama3.py +366 -0
  125. tpu_inference/models/jax/llama4.py +473 -0
  126. tpu_inference/models/jax/llama_eagle3.py +333 -0
  127. tpu_inference/models/jax/phi3.py +376 -0
  128. tpu_inference/models/jax/qwen2.py +375 -0
  129. tpu_inference/models/jax/qwen2_5_vl.py +976 -0
  130. tpu_inference/models/jax/qwen3.py +302 -0
  131. tpu_inference/models/jax/utils/__init__.py +0 -0
  132. tpu_inference/models/jax/utils/file_utils.py +96 -0
  133. tpu_inference/models/jax/utils/multi_modal_utils.py +164 -0
  134. tpu_inference/models/jax/utils/quantization/__init__.py +0 -0
  135. tpu_inference/models/jax/utils/quantization/quantization_utils.py +588 -0
  136. tpu_inference/models/jax/utils/weight_utils.py +510 -0
  137. tpu_inference/models/vllm/__init__.py +0 -0
  138. tpu_inference/models/vllm/vllm_model_wrapper.py +272 -0
  139. tpu_inference/models/vllm/vllm_model_wrapper_context.py +45 -0
  140. tpu_inference/platforms/__init__.py +2 -0
  141. tpu_inference/platforms/tpu_jax.py +257 -0
  142. tpu_inference/runner/__init__.py +0 -0
  143. tpu_inference/runner/block_table_jax.py +122 -0
  144. tpu_inference/runner/compilation_manager.py +672 -0
  145. tpu_inference/runner/input_batch_jax.py +435 -0
  146. tpu_inference/runner/kv_cache.py +119 -0
  147. tpu_inference/runner/kv_cache_manager.py +460 -0
  148. tpu_inference/runner/lora_utils.py +92 -0
  149. tpu_inference/runner/multimodal_manager.py +208 -0
  150. tpu_inference/runner/persistent_batch_manager.py +244 -0
  151. tpu_inference/runner/speculative_decoding_manager.py +250 -0
  152. tpu_inference/runner/structured_decoding_manager.py +89 -0
  153. tpu_inference/runner/tpu_jax_runner.py +771 -0
  154. tpu_inference/runner/utils.py +426 -0
  155. tpu_inference/spec_decode/__init__.py +0 -0
  156. tpu_inference/spec_decode/jax/__init__.py +0 -0
  157. tpu_inference/spec_decode/jax/eagle3.py +334 -0
  158. tpu_inference/tpu_info.py +77 -0
  159. tpu_inference/utils.py +294 -0
  160. tpu_inference/worker/__init__.py +0 -0
  161. tpu_inference/worker/_temporary_vllm_compat.py +129 -0
  162. tpu_inference/worker/base.py +100 -0
  163. tpu_inference/worker/tpu_worker_jax.py +321 -0
  164. tpu_inference-0.11.1.dist-info/METADATA +101 -0
  165. tpu_inference-0.11.1.dist-info/RECORD +168 -0
  166. tpu_inference-0.11.1.dist-info/WHEEL +5 -0
  167. tpu_inference-0.11.1.dist-info/licenses/LICENSE +201 -0
  168. tpu_inference-0.11.1.dist-info/top_level.txt +2 -0
tpu_inference/runner/multimodal_manager.py
@@ -0,0 +1,208 @@
+ from typing import TYPE_CHECKING
+
+ import jax
+ import jax.numpy as jnp
+ from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
+ from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
+ from vllm.multimodal.utils import group_mm_kwargs_by_modality
+ from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
+ from vllm.v1.worker.utils import (gather_mm_placeholders,
+                                   scatter_mm_placeholders)
+
+ from tpu_inference.models.jax.utils.multi_modal_utils import \
+     sanity_check_mm_encoder_outputs
+
+ if TYPE_CHECKING:
+     from tpu_inference.runner.tpu_jax_runner import TPUModelRunner
+
+
+ class MultiModalManager:
+
+     def __init__(self, runner: "TPUModelRunner"):
+         self.runner = runner
+
+     def calc_mrope_positions(self, scheduler_output: "VllmSchedulerOutput"):
+         mrope_pos_ptr = 0
+         for index, req_id in enumerate(self.runner.input_batch.req_ids):
+             req = self.runner.requests[req_id]
+             assert req.mrope_positions is not None
+
+             num_computed_tokens = \
+                 self.runner.input_batch.num_computed_tokens_cpu[index]
+             num_scheduled_tokens = \
+                 scheduler_output.num_scheduled_tokens[req_id]
+             num_prompt_tokens = len(req.prompt_token_ids)
+
+             if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
+                 prompt_part_len = max(0,
+                                       num_prompt_tokens - num_computed_tokens)
+                 completion_part_len = max(
+                     0, num_scheduled_tokens - prompt_part_len)
+             else:
+                 prompt_part_len = num_scheduled_tokens
+                 completion_part_len = 0
+
+             assert num_scheduled_tokens == prompt_part_len + completion_part_len
+
+             if prompt_part_len > 0:
+                 # prompt's mrope_positions are pre-computed
+                 dst_start = mrope_pos_ptr
+                 dst_end = mrope_pos_ptr + prompt_part_len
+                 src_start = num_computed_tokens
+                 src_end = num_computed_tokens + prompt_part_len
+
+                 self.runner.mrope_positions_cpu[:, dst_start:dst_end] = \
+                     req.mrope_positions[:, src_start:src_end]
+
+                 mrope_pos_ptr += prompt_part_len
+
+             if completion_part_len > 0:
+                 # compute completion's mrope_positions on-the-fly
+                 dst_start = mrope_pos_ptr
+                 dst_end = mrope_pos_ptr + completion_part_len
+
+                 MRotaryEmbedding.get_next_input_positions_tensor(
+                     out=self.runner.mrope_positions_cpu,
+                     out_offset=dst_start,
+                     mrope_position_delta=req.mrope_position_delta,
+                     context_len=num_computed_tokens + prompt_part_len,
+                     num_new_tokens=completion_part_len,
+                 )
+
+                 mrope_pos_ptr += completion_part_len
+
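The prompt/completion split above decides which M-RoPE positions are copied from the precomputed table and which are generated on the fly. Below is a minimal standalone sketch of just that arithmetic with two worked cases; the helper name split_scheduled_tokens is illustrative and not part of the package.

def split_scheduled_tokens(num_prompt_tokens: int,
                           num_computed_tokens: int,
                           num_scheduled_tokens: int) -> tuple[int, int]:
    """Split a request's scheduled tokens into (prompt part, completion part)."""
    if num_computed_tokens + num_scheduled_tokens > num_prompt_tokens:
        # The scheduled window crosses the prompt/completion boundary.
        prompt_part_len = max(0, num_prompt_tokens - num_computed_tokens)
        completion_part_len = max(0, num_scheduled_tokens - prompt_part_len)
    else:
        # Still entirely inside the prompt (e.g. a chunked-prefill step).
        prompt_part_len = num_scheduled_tokens
        completion_part_len = 0
    return prompt_part_len, completion_part_len


# A 10-token prompt with 8 tokens already computed and 5 tokens scheduled:
# 2 tokens finish the prompt, 3 are new completion tokens.
assert split_scheduled_tokens(10, 8, 5) == (2, 3)
# A chunked-prefill step that stays inside the prompt.
assert split_scheduled_tokens(100, 0, 16) == (16, 0)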
+     def execute_mm_encoder(self, scheduler_output: "VllmSchedulerOutput"):
+         import torch
+         scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
+         if not scheduled_encoder_inputs:
+             return
+
+         # Batch the multi-modal inputs.
+         mm_kwargs = list[MultiModalKwargsItem]()
+         # List of tuples (mm_hash, pos_info).
+         mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
+         for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
+             req_state = self.runner.requests[req_id]
+             for mm_input_id in encoder_input_ids:
+                 mm_feature = req_state.mm_features[mm_input_id]
+                 mm_hash = mm_feature.identifier
+                 mm_kwargs.append(mm_feature.data)
+                 mm_hashes_pos.append((mm_hash, mm_feature.mm_position))
+
+         # Batch mm inputs as much as we can: if a request in the batch has
+         # multiple modalities or a different modality than the previous one,
+         # we process it separately to preserve item order.
+         # FIXME(ywang96): This is a hacky way to deal with multiple modalities
+         # in the same batch while still being able to benefit from batching
+         # multimodal inputs. The proper solution should be reordering the
+         # encoder outputs.
+         encoder_outputs = []
+         for _, num_items, mm_kwargs_group in group_mm_kwargs_by_modality(
+                 mm_kwargs, merge_by_field_config=False):
+             batched_mm_inputs = mm_kwargs_group
+             # Convert torch tensors to numpy arrays that JAX can handle.
+             if "pixel_values" in batched_mm_inputs and isinstance(
+                     batched_mm_inputs["pixel_values"], list):
+                 batched_mm_inputs["pixel_values"] = torch.cat(
+                     batched_mm_inputs["pixel_values"], dim=0)
+
+             image_grid_thw = ()
+             for key, value in batched_mm_inputs.items():
+                 if isinstance(value, torch.Tensor):
+                     if key == 'image_grid_thw':
+                         # Change it to a tuple of tuples to make it hashable
+                         # for JIT.
+                         # Shape: (B, N, 3) -> (B*N, 3) -> tuple of tuples
+                         grid_thw_tensor = batched_mm_inputs[key]
+                         grid_thw_reshaped = grid_thw_tensor.reshape(-1, 3)
+                         image_grid_thw = tuple(
+                             tuple(row) for row in grid_thw_reshaped.tolist())
+                         continue
+
+                     if value.dtype == torch.bfloat16:
+                         batched_mm_inputs[key] = value.to(
+                             torch.float32).numpy().astype(jnp.bfloat16)
+                     else:
+                         batched_mm_inputs[key] = value.numpy()
+             batched_mm_inputs.pop('image_grid_thw', None)
+
+             # Run the encoder.
+             # `curr_group_outputs` is either of the following:
+             # 1. A tensor of shape (num_items, feature_size, hidden_size)
+             #    in case feature_size is fixed across all multimodal items.
+             # 2. A list or tuple (length: num_items) of tensors, each of shape
+             #    (feature_size, hidden_size) in case the feature size is
+             #    dynamic depending on the input multimodal items.
+             curr_group_outputs = self.runner.get_multimodal_embeddings_fn(
+                 self.runner.state, image_grid_thw, **batched_mm_inputs)
+
+             sanity_check_mm_encoder_outputs(
+                 curr_group_outputs,
+                 expected_num_items=num_items,
+             )
+
+             for output in curr_group_outputs:
+                 encoder_outputs.append(output)
+
+         # Cache the encoder outputs, keyed by the multimodal item's hash.
+         for (mm_hash, pos_info), output in zip(
+                 mm_hashes_pos,
+                 encoder_outputs,
+         ):
+             self.runner.encoder_cache[mm_hash] = scatter_mm_placeholders(
+                 output,
+                 is_embed=pos_info.is_embed,
+             )
+
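The image_grid_thw conversion above turns a tensor into a tuple of tuples because the grid is handed to a jit-compiled function as a hashable static argument rather than a traced array. A small self-contained sketch of that pattern follows; the encode function is a stand-in written for illustration, not the package's encoder.

from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("image_grid_thw",))
def encode(pixel_values: jax.Array,
           image_grid_thw: tuple[tuple[int, ...], ...]) -> jax.Array:
    # The static grid is visible as plain Python values at trace time,
    # so it can drive shape logic; the pixels remain a traced array.
    num_patches = sum(t * h * w for t, h, w in image_grid_thw)
    return pixel_values[:num_patches] * 2.0


pixels = jnp.ones((64, 8), dtype=jnp.bfloat16)
grid = ((1, 4, 4), (1, 2, 2))   # hashable, unlike a torch.Tensor or np.ndarray
out = encode(pixels, grid)      # retracing happens only when `grid` changes
print(out.shape)                # (20, 8)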
+     def gather_mm_embeddings(
+         self,
+         scheduler_output: "VllmSchedulerOutput",
+     ) -> list[jax.Array]:
+         mm_embeds: list[jax.Array] = []
+         for req_id in self.runner.input_batch.req_ids:
+             num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
+                 req_id]
+             req_state = self.runner.requests[req_id]
+             num_computed_tokens = req_state.num_computed_tokens
+             mm_features = req_state.mm_features
+             for mm_feature in mm_features:
+                 pos_info = mm_feature.mm_position
+                 start_pos = pos_info.offset
+                 num_encoder_tokens = pos_info.length
+
+                 # The encoder output is needed if the two ranges overlap:
+                 # [num_computed_tokens,
+                 #  num_computed_tokens + num_scheduled_tokens) and
+                 # [start_pos, start_pos + num_encoder_tokens)
+                 if start_pos >= num_computed_tokens + num_scheduled_tokens:
+                     # The encoder output is not needed in this step.
+                     break
+                 if start_pos + num_encoder_tokens <= num_computed_tokens:
+                     # The encoder output is already processed and stored
+                     # in the decoder's KV cache.
+                     continue
+
+                 start_idx = max(num_computed_tokens - start_pos, 0)
+                 end_idx = min(
+                     num_computed_tokens - start_pos + num_scheduled_tokens,
+                     num_encoder_tokens)
+                 assert start_idx < end_idx
+                 mm_hash = mm_feature.identifier
+                 encoder_output = self.runner.encoder_cache.get(mm_hash, None)
+                 assert encoder_output is not None, \
+                     f"Encoder cache miss for {mm_hash}."
+
+                 if (is_embed := pos_info.is_embed) is not None:
+                     is_embed = is_embed[start_idx:end_idx]
+
+                 mm_embeds_item = gather_mm_placeholders(
+                     encoder_output[start_idx:end_idx],
+                     is_embed=is_embed,
+                 )
+                 mm_embeds.append(mm_embeds_item)
+         return mm_embeds
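The start_idx/end_idx computation in gather_mm_embeddings selects exactly the part of a cached encoder output that overlaps the tokens scheduled in this step. Here is a standalone sketch of that intersection with a worked example; encoder_slice is an illustrative helper, not part of the package.

def encoder_slice(num_computed_tokens: int, num_scheduled_tokens: int,
                  start_pos: int, num_encoder_tokens: int):
    """Return the [start_idx, end_idx) slice of the encoder output needed now,
    or None if the placeholder range does not overlap this step's tokens."""
    step_start = num_computed_tokens
    step_end = num_computed_tokens + num_scheduled_tokens
    mm_start = start_pos
    mm_end = start_pos + num_encoder_tokens

    if mm_start >= step_end:    # the multimodal item starts after this step
        return None
    if mm_end <= step_start:    # already fully consumed into the KV cache
        return None

    start_idx = max(step_start - mm_start, 0)
    end_idx = min(step_end - mm_start, num_encoder_tokens)
    return start_idx, end_idx


# An image occupying positions [100, 200) while this step covers tokens
# [150, 180): only rows 50..80 of the encoder output are gathered.
assert encoder_slice(150, 30, 100, 100) == (50, 80)
# The same image when the step is entirely past it: nothing to gather.
assert encoder_slice(220, 30, 100, 100) is None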
tpu_inference/runner/persistent_batch_manager.py
@@ -0,0 +1,244 @@
+ from typing import Dict
+
+ import jax
+ from vllm.v1.core.sched.output import SchedulerOutput as VllmSchedulerOutput
+
+ from tpu_inference.logger import init_logger
+ from tpu_inference.runner.input_batch_jax import CachedRequestState, InputBatch
+
+ logger = init_logger(__name__)
+
+
+ class PersistentBatchManager:
+
+     def __init__(self, requests: Dict[str, CachedRequestState],
+                  input_batch: InputBatch,
+                  encoder_cache: Dict[str, 'jax.Array'],
+                  uses_mrope: bool, model_config):
+         self.requests = requests
+         self.input_batch = input_batch
+         self.encoder_cache = encoder_cache
+         self.uses_mrope = uses_mrope
+         self.model_config = model_config
+
+     def _reorder_batch(self, scheduler_output: "VllmSchedulerOutput") -> int:
+         """Reorder the scheduled requests into an RPA-kernel-friendly
+         distribution (decode_only, fixed_chunked_prefill_only, mixed) and set
+         the request distribution accordingly.
+
+         Returns:
+             The number of request swaps performed.
+         """
+         # Note(jevinjiang): currently we only consider decode_only.
+         num_reqs = self.input_batch.num_reqs
+         swap_cnt = 0
+         if num_reqs <= 0:
+             return swap_cnt
+         # Use a two-pointer approach to move the decode requests to the front.
+         i, j = 0, num_reqs - 1
+         while i < j:
+             i_req_id = self.input_batch.req_ids[i]
+             j_req_id = self.input_batch.req_ids[j]
+
+             if scheduler_output.num_scheduled_tokens[i_req_id] == 1:
+                 # i is a decode request, move to the next one.
+                 i += 1
+             elif scheduler_output.num_scheduled_tokens[j_req_id] > 1:
+                 # j is a prefill request, move to the previous one.
+                 j -= 1
+             else:
+                 # Swap i and j.
+                 self.input_batch.swap_states(i, j)
+                 i += 1
+                 j -= 1
+                 swap_cnt += 1
+
+         num_decode = i + int(scheduler_output.num_scheduled_tokens[
+             self.input_batch.req_ids[i]] == 1)
+
+         self.input_batch.request_distribution = [
+             num_decode, num_decode, num_reqs
+         ]
+
+         return swap_cnt
+
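The two-pointer pass in _reorder_batch groups decode requests (exactly one scheduled token) at the front of the batch so the ragged paged attention kernel sees a (decode_only, chunked_prefill_only, mixed) layout. The sketch below replays the same swap logic on a plain list, detached from InputBatch; all names here are illustrative.

def reorder_decode_first(req_ids: list[str],
                         num_scheduled_tokens: dict[str, int]) -> int:
    """Swap requests in place so decodes (1 scheduled token) come first.

    Returns the number of swaps performed.
    """
    swap_cnt = 0
    i, j = 0, len(req_ids) - 1
    while i < j:
        if num_scheduled_tokens[req_ids[i]] == 1:
            i += 1            # already a decode, keep it in front
        elif num_scheduled_tokens[req_ids[j]] > 1:
            j -= 1            # already a prefill, keep it at the back
        else:
            req_ids[i], req_ids[j] = req_ids[j], req_ids[i]
            i += 1
            j -= 1
            swap_cnt += 1
    return swap_cnt


reqs = ["p0", "d0", "p1", "d1"]
tokens = {"p0": 17, "d0": 1, "p1": 9, "d1": 1}
swaps = reorder_decode_first(reqs, tokens)
print(reqs, swaps)   # ['d1', 'd0', 'p1', 'p0'] 1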
66
+ get_mrope_input_positions_fn) -> bool:
67
+ """Update the cached states and the persistent batch with the scheduler
68
+ output.
69
+
70
+ The updated states are used by the `_prepare_inputs` function to create
71
+ the input TPU tensors for the model.
72
+
73
+ Returns:
74
+ True if there is a new/resumed/paused/finished request.
75
+ If False, we can skip copying SamplingMetadata to the TPU.
76
+ """
77
+ # Remove finished requests from the cached states.
78
+ for req_id in scheduler_output.finished_req_ids:
79
+ self.requests.pop(req_id, None)
80
+
81
+ # Remove the finished requests from the persistent batch.
82
+ # NOTE(woosuk): There could be an edge case where finished_req_ids and
83
+ # scheduled_req_ids overlap. This happens when a request is aborted and
84
+ # then resubmitted with the same ID. In this case, we treat them as two
85
+ # distinct requests - clearing the cached states for the first request
86
+ # and handling the second as a new request.
87
+ removed_req_indices: list[int] = []
88
+ for req_id in scheduler_output.finished_req_ids:
89
+ req_index = self.input_batch.remove_request(req_id)
90
+ if req_index is not None:
91
+ removed_req_indices.append(req_index)
92
+
93
+ # Free the cached encoder outputs.
94
+ for mm_hash in scheduler_output.free_encoder_mm_hashes:
95
+ self.encoder_cache.pop(mm_hash, None)
96
+
97
+ # Remove the unscheduled requests from the persistent batch.
98
+ # NOTE(woosuk): The unscheduled requests are either preempted requests
99
+ # or running requests that are not scheduled in this step. We remove
100
+ # them from the persistent batch but keep their cached states since
101
+ # they will be scheduled again sometime in the future.
102
+ scheduled_req_ids = scheduler_output.num_scheduled_tokens.keys()
103
+ cached_req_ids = self.input_batch.req_id_to_index.keys()
104
+ unscheduled_req_ids = cached_req_ids - scheduled_req_ids
105
+ # NOTE(woosuk): The persistent batch optimization assumes that
106
+ # consecutive batches contain mostly the same requests. If batches
107
+ # have low request overlap (e.g., alternating between two distinct
108
+ # sets of requests), this optimization becomes very inefficient.
109
+ for req_id in unscheduled_req_ids:
110
+ req_index = self.input_batch.remove_request(req_id)
111
+ assert req_index is not None
112
+ removed_req_indices.append(req_index)
113
+
114
+ req_ids_to_add: list[str] = []
115
+ # Add new requests to the cached states.
116
+ for new_req_data in scheduler_output.scheduled_new_reqs:
117
+ req_id = new_req_data.req_id
118
+ sampling_params = new_req_data.sampling_params
119
+
120
+ self.requests[req_id] = CachedRequestState(
121
+ req_id=req_id,
122
+ prompt_token_ids=new_req_data.prompt_token_ids,
123
+ mm_features=new_req_data.mm_features,
124
+ sampling_params=sampling_params,
125
+ pooling_params=None,
126
+ generator=None,
127
+ block_ids=new_req_data.block_ids,
128
+ num_computed_tokens=new_req_data.num_computed_tokens,
129
+ output_token_ids=[],
130
+ lora_request=new_req_data.lora_request,
131
+ )
132
+
133
+ req_ids_to_add.append(req_id)
134
+
135
+ # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
136
+ if self.uses_mrope:
137
+ image_grid_thw = []
138
+ video_grid_thw = []
139
+ second_per_grid_ts = []
140
+ audio_feature_lengths = []
141
+ use_audio_in_video = False
142
+ for mm_feature in self.requests[req_id].mm_features:
143
+ item = mm_feature.data
144
+ if item is None:
145
+ continue
146
+ mm_input = item.get_data()
147
+ if mm_input.get("image_grid_thw") is not None:
148
+ image_grid_thw.append(
149
+ mm_input["image_grid_thw"].tolist())
150
+ if mm_input.get("video_grid_thw") is not None:
151
+ video_grid_thw.append(
152
+ mm_input["video_grid_thw"].tolist())
153
+ if mm_input.get("second_per_grid_ts") is not None:
154
+ second_per_grid_ts.append(
155
+ mm_input["second_per_grid_ts"])
156
+ if mm_input.get("audio_feature_lengths") is not None:
157
+ audio_feature_lengths.append(
158
+ mm_input["audio_feature_lengths"])
159
+ if mm_input.get("use_audio_in_video") is True:
160
+ use_audio_in_video = True
161
+
162
+ hf_config = self.model_config.hf_config
163
+
164
+ self.requests[req_id].mrope_positions, self.requests[
165
+ req_id].mrope_position_delta = get_mrope_input_positions_fn(
166
+ self.requests[req_id].prompt_token_ids,
167
+ hf_config=hf_config,
168
+ image_grid_thw=image_grid_thw,
169
+ video_grid_thw=video_grid_thw,
170
+ second_per_grid_ts=second_per_grid_ts,
171
+ audio_feature_lengths=audio_feature_lengths,
172
+ use_audio_in_video=use_audio_in_video,
173
+ )
174
+
175
+ # Update the states of the running/resumed requests.
176
+ req_data = scheduler_output.scheduled_cached_reqs
177
+ for i, req_id in enumerate(req_data.req_ids):
178
+ req_state = self.requests[req_id]
179
+ num_computed_tokens = req_data.num_computed_tokens[i]
180
+ new_block_ids = req_data.new_block_ids[i]
181
+ resumed_from_preemption = req_data.resumed_from_preemption[i]
182
+
183
+ # Update the cached states.
184
+ req_state.num_computed_tokens = num_computed_tokens
185
+ if not resumed_from_preemption:
186
+ if new_block_ids is not None:
187
+ # Append the new blocks to the existing block IDs.
188
+ for block_ids, new_ids in zip(req_state.block_ids,
189
+ new_block_ids):
190
+ block_ids.extend(new_ids)
191
+ else:
192
+ assert new_block_ids is not None
193
+ # The request is resumed from preemption.
194
+ # Replace the existing block IDs with the new ones.
195
+ req_state.block_ids = new_block_ids
196
+
197
+ req_index = self.input_batch.req_id_to_index.get(req_id)
198
+ if req_index is None:
199
+ # The request is not in the persistent batch.
200
+ # The request was either preempted and resumed later, or was not
201
+ # scheduled in the previous step and needs to be added again.
202
+ req_ids_to_add.append(req_id)
203
+ continue
204
+
205
+ # Update the persistent batch.
206
+ self.input_batch.num_computed_tokens_cpu[
207
+ req_index] = num_computed_tokens
208
+ if new_block_ids is not None:
209
+ self.input_batch.block_table.append_row(
210
+ new_block_ids, req_index)
211
+
212
+ # Add spec_token_ids to token_ids_cpu.
213
+ spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
214
+ req_id, ())
215
+ if spec_token_ids:
216
+ num_spec_tokens = len(spec_token_ids)
217
+ start_index = self.input_batch.num_tokens_no_spec[req_index]
218
+ end_token_index = start_index + num_spec_tokens
219
+ self.input_batch.token_ids_cpu[
220
+ req_index, start_index:end_token_index] = spec_token_ids
221
+ # NOTE(woosuk): `num_tokens` here may include spec tokens.
222
+ self.input_batch.num_tokens[req_index] += num_spec_tokens
223
+
224
+ # Add the new or resumed requests to the persistent batch.
225
+ # The smaller empty indices are filled first.
226
+ removed_req_indices = sorted(removed_req_indices, reverse=True)
227
+ for req_id in req_ids_to_add:
228
+ req_state = self.requests[req_id]
229
+ if removed_req_indices:
230
+ # Fill the empty index.
231
+ req_index = removed_req_indices.pop()
232
+ else:
233
+ # Append to the end.
234
+ req_index = None
235
+ self.input_batch.add_request(req_state, req_index)
236
+
237
+ # Condense the batched states if there are empty indices.
238
+ if removed_req_indices:
239
+ self.input_batch.condense(removed_req_indices)
240
+
241
+ batch_changed = len(unscheduled_req_ids) > 0 or len(req_ids_to_add) > 0
242
+ # TODO(jevinjiang): I assume we do not need to set batch_changed to true if just swapping requests.
243
+ self._reorder_batch(scheduler_output)
244
+ return batch_changed
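At the end of update_states, slots freed by removed requests are reused smallest-index-first for new or resumed requests, and any holes that remain are condensed afterwards. Below is a minimal sketch of that slot-reuse pattern over a plain list standing in for the persistent batch; refill_slots is an illustrative helper, not the package's API.

def refill_slots(batch: list, removed_indices: list[int], new_items: list) -> None:
    """Fill freed slots (smallest index first), then append what is left over."""
    # Sort descending so .pop() hands back the smallest index first.
    removed_indices = sorted(removed_indices, reverse=True)
    for item in new_items:
        if removed_indices:
            batch[removed_indices.pop()] = item   # reuse an empty slot
        else:
            batch.append(item)                    # grow the batch at the end
    # Any slots still empty would then be condensed, i.e. trailing requests
    # moved into the remaining holes so the batch stays contiguous.


batch = ["r0", None, "r2", None, "r4"]   # slots 1 and 3 were just freed
refill_slots(batch, [3, 1], ["r5", "r6", "r7"])
print(batch)   # ['r0', 'r5', 'r2', 'r6', 'r4', 'r7']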