sglang 0.4.2__py3-none-any.whl → 0.4.2.post1__py3-none-any.whl

@@ -166,6 +166,12 @@ def _fwd_kernel(
 def context_attention_fwd(
     q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
 ):
+    """
+    q, k, v: [b * s, head, head_dim]
+    b_start_loc: [b]
+    b_seq_len: [b]
+    out: [b * s, head, head_dim]
+    """
     if is_cuda_available and CUDA_CAPABILITY[0] > 8:
         BLOCK = 128
     else:
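
The new docstring pins down the packed layout this Triton prefill kernel expects: all sequences are concatenated along the first axis, with per-sequence offsets and lengths passed separately. A minimal usage sketch, assuming a CUDA device and that the kernel is importable from `sglang.srt.layers.attention.triton_ops.prefill_attention` as in this release:

```python
# Sketch only: packs two sequences (lengths 3 and 5) into one [b*s, head, head_dim]
# tensor and runs the non-causal prefill kernel. Import path and dtypes are assumptions.
import torch
from sglang.srt.layers.attention.triton_ops.prefill_attention import context_attention_fwd

seq_lens = torch.tensor([3, 5], dtype=torch.int32, device="cuda")   # b_seq_len: [b]
start_loc = torch.tensor([0, 3], dtype=torch.int32, device="cuda")  # b_start_loc: [b]
total_tokens, heads, head_dim = int(seq_lens.sum()), 8, 64

q = torch.randn(total_tokens, heads, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
o = torch.empty_like(q)                                             # out: [b*s, head, head_dim]

context_attention_fwd(q, k, v, o, start_loc, seq_lens,
                      max_input_len=int(seq_lens.max()), is_causal=False)
```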
@@ -4,6 +4,7 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from einops import rearrange, repeat
 
 from sglang.srt.distributed import parallel_state
@@ -63,7 +64,20 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.T
 
 
 class VisionAttention(nn.Module):
-    """Multi-headed attention without any cache, mostly used for ViT."""
+    r"""
+    Multi-headed attention without any cache, mostly used for ViT.
+
+
+    Args:
+        use_qkv_parallel (bool, optional): If True, use QKV-parallel attention.
+        use_context_forward (bool, default to True):
+            if ``True``, a flash_attn style attention will be applied
+            Otherwise, a full-sequence attention will be applied.
+        use_full_precision_softmax (bool, default to False):
+            if ``True``, the softmax will be performed in full-precision
+            Otherwise, it will be performed in half-precision
+
+    """
 
     def __init__(
         self,
@@ -72,25 +86,39 @@ class VisionAttention(nn.Module):
         projection_size: int,
         use_qkv_parallel: bool,
         quant_config: Optional[QuantizationConfig] = None,
+        dropout: float = 0.0,
+        use_context_forward: bool = True,
+        use_full_precision_softmax: bool = False,
+        flatten_batch: bool = False,
         prefix: str = "",
     ):
         super().__init__()
+        self.use_context_forward = use_context_forward
         world_size = parallel_state.get_tensor_model_parallel_world_size()
-
+        self.dropout = dropout
+        self.head_size = embed_dim // num_heads
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads
         )
         self.num_attention_heads_per_partition = dist_utils.divide(
             num_heads, world_size
         )
-        # self.tp_size = get_tensor_model_parallel_world_size()
-        # num_heads = self.num_heads_per_partition
+
+        if self.use_context_forward:
+            self.qkv_backend = VisionTritonAttention()
+        else:
+            self.qkv_backend = VisionSdpaAttention(
+                head_size=self.head_size,
+                dropout=dropout,
+                flatten_batch=flatten_batch,
+                use_full_precision_softmax=use_full_precision_softmax,
+            )
+
         self.use_qkv_parallel = use_qkv_parallel
         if use_qkv_parallel:
-            self.head_dim = embed_dim // num_heads
             self.qkv_proj = QKVParallelLinear(
                 hidden_size=embed_dim,
-                head_size=self.head_dim,
+                head_size=self.head_size,
                 total_num_heads=num_heads,
                 quant_config=quant_config,
                 prefix=f"{prefix}.qkv_proj",
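
With these constructor changes, the attention backend is chosen once at init time: `use_context_forward=True` (the default) routes `forward` through the Triton kernel above, while `False` selects the SDPA fallback, optionally with an fp32 softmax. A construction sketch with illustrative sizes, assuming the class lives in `sglang.srt.layers.attention.vision` (the real module also expects tensor-parallel state to be initialized):

```python
# Sketch only: shows how the new flags map to a backend choice; values are illustrative.
from sglang.srt.layers.attention.vision import VisionAttention

attn = VisionAttention(
    embed_dim=1280,
    num_heads=16,
    projection_size=1280,
    use_qkv_parallel=False,
    dropout=0.0,
    use_context_forward=False,        # False -> VisionSdpaAttention backend
    use_full_precision_softmax=True,  # fp32 softmax inside the SDPA fallback
    flatten_batch=True,
)
```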
@@ -114,12 +142,15 @@ class VisionAttention(nn.Module):
         x: torch.Tensor,
         cu_seqlens: Optional[torch.Tensor] = None,
         rotary_pos_emb: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        r"""
+        Args:
+            x: [b, s, embed_dim]
+            cu_seqlens: [b]
+        Returns:
+            [s, b, num_heads * head]
         """
-        Input shape: [b, s, embed_dim]
-        Output shape: [s, b, num_heads * head_size]
-        """
-
         bsz, s, _ = x.shape
         if self.use_qkv_parallel:
             # [b, s, embed_dim] --> [b, s, embed_dim]
@@ -136,19 +167,19 @@ class VisionAttention(nn.Module):
         else:
             # [b, s, embed_dim] --> [s, b, embed_dim]
             x = rearrange(x, "b s ... -> s b ...")
-            # [s, b, embed_dim] --> [s, b, head * 3 * head_dim]
+            # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
             qkv, _ = self.qkv_proj(x)
-            # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
+            # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
             new_x_shape = qkv.size()[:-1] + (
                 self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head,
             )
             qkv = qkv.view(*new_x_shape)
 
-            # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
+            # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
             q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
 
-            # [s, b, head, head_dim] --> [b, s, head, head_dim]
+            # [s, b, head, head_size] --> [b, s, head, head_size]
             q, k, v = [
                 rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
             ]
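
The comment renames in this hunk are cosmetic (`head_dim` -> `head_size`), but the shape bookkeeping they describe is easy to check in isolation. A small self-contained sketch of the same split-and-transpose, using `torch.chunk` in place of `dist_utils.split_tensor_along_last_dim`:

```python
# Sketch only: reproduces the shape bookkeeping described by the updated comments,
# using plain tensors instead of the real qkv projection.
import torch
from einops import rearrange

s, b, head, head_size = 4, 2, 8, 64
qkv = torch.randn(s, b, head, 3 * head_size)   # [s, b, head, 3 * head_size]

q, k, v = qkv.chunk(3, dim=-1)                 # 3 x [s, b, head, head_size]
q, k, v = [rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)]
print(q.shape)                                 # torch.Size([2, 4, 8, 64])
```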
@@ -160,45 +191,217 @@ class VisionAttention(nn.Module):
         if self.use_qkv_parallel:
             pass
         else:
-            # [b, s, head, head_dim] --> [b * s, head, head_dim]
+            # [b, s, head, head_size] --> [b * s, head, head_size]
             q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
 
-        # [b * s, num_heads, head_size]
-        output = torch.empty_like(q)
-
-        seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cuda()
-        max_seqlen = seq_lens.max().item()
-
-        context_attention_fwd(
-            q,
-            k,
-            v,
-            output,
-            cu_seqlens.cuda(),
-            seq_lens,
-            max_seqlen,
-            is_causal=False,
-        )
+        output = self.qkv_backend.forward(q, k, v, bsz, cu_seqlens, attention_mask)
 
         if self.use_qkv_parallel:
-
-            # [b * s, head, head_dim] --> [b, s, head * head_dim]
+            # [b * s, h, head_size] --> [b, s, h * head_size]
             output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
 
-            # [b, s, head, head_dim] --> [b, s, head, head_dim]
+            # [b, s, h * head_size] --> [b, s, h * head_size]
             output, _ = self.proj(output)
         else:
-            # [b * s, head, head_dim] --> [b, s, head, head_dim]
-            context_layer = rearrange(output, "(b s) ... -> b s ...", b=bsz)
-
-            # [s, b, num_heads * head_size]
+            # [b * s, h, head_size] --> [s, b, h * head_size]
             context_layer = rearrange(
-                context_layer, "b s h d -> s b (h d)"
+                output, "(b s) h d -> s b (h d)", b=bsz, s=s
             ).contiguous()
 
-            # [s, b, num_heads * head_size] --> [s, b, num_heads * head_size]
+            # [s, b, h * head_size] --> [s, b, h * head_size]
             output, _ = self.proj(context_layer)
 
+        # [s, b, h * head_size] --> [b, s, h * head_size]
         output = output.view(bsz, s, -1)
 
         return output
+
+
+class VisionSdpaAttention(nn.Module):
+    r"""
+    Scaled Dot Product Attention inner product
+
+    """
+
+    # TODO: Should it be released after used?
+    _mask_cache = {}
+
+    def __init__(
+        self,
+        head_size: int,
+        dropout: float = 0.0,
+        flatten_batch: bool = False,
+        use_full_precision_softmax: bool = False,
+    ):
+        super().__init__()
+        self.head_size = head_size
+        self.flatten_batch = flatten_batch
+        self.use_full_precision_softmax = use_full_precision_softmax
+        self.dropout = dropout
+
+    def generate_patch_attention_mask(
+        self,
+        s: int,
+        bsz: int,
+        device,
+        cu_seqlens: Optional[torch.Tensor],
+        flatten_batch: bool = False,
+        dtype=torch.bfloat16,
+    ) -> torch.Tensor:
+        r"""
+        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+
+        When `flatten_batch` is True:
+            - All sequences in the batch are flattened into a single dimension
+            - `s` represents the total number of tokens across all sequences in the batch
+            - Returns a unified mask of shape `(1, 1, s, s)`
+
+        When `flatten_batch` is False:
+            - Each sequence has its own attention mask
+            - `s` represents the maximum sequence length in the batch
+            - Returns separate masks of shape `(b, 1, s, s)`
+
+        Args:
+            flatten_batch: (bool):
+                If True, treats all sequences in the batch as a single flattened sequence
+                If False, generates separate masks for each sequence
+
+        Returns:
+            Tensor of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+        """
+
+        cache_key = (s, bsz, flatten_batch, tuple(cu_seqlens.cpu().tolist()))
+
+        if cache_key in VisionSdpaAttention._mask_cache:
+            cached_mask = VisionSdpaAttention._mask_cache[cache_key]
+            # print(f"cache hit for key: {cache_key}")
+            return cached_mask.to(device=device, dtype=dtype)
+
+        if cu_seqlens is None:
+            raise ValueError("Internal Error: cu_seqlens cannot be None")
+
+        if flatten_batch:
+            mask = torch.zeros([1, s, s], device=device, dtype=torch.bool)
+            for i in range(1, len(cu_seqlens)):
+                start = cu_seqlens[i - 1]
+                end = cu_seqlens[i]
+                mask[
+                    ...,
+                    start:end,
+                    start:end,
+                ] = True
+        else:
+            # [1, 1, 1, s]
+            row_indices = torch.arange(s, device=device).view(1, 1, 1, s)
+            # [1, 1, s, 1]
+            col_indices = torch.arange(s, device=device).view(1, 1, s, 1)
+            # [b, 1, 1, 1]
+            seq_lens = (
+                (cu_seqlens[1:] - cu_seqlens[:-1]).to(device=device).view(-1, 1, 1, 1)
+            )
+
+            mask = (row_indices < seq_lens) & (col_indices < seq_lens)
+
+        # Convert to attention mask format (False -> 0, True -> -inf)
+        mask = (~mask).to(dtype) * torch.finfo(dtype).min
+
+        VisionSdpaAttention._mask_cache[cache_key] = mask
+
+        return mask
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz: int,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            cu_seqlens: [b]
+        Returns:
+            [b * s, h, head_size]
+        """
+
+        s = q.shape[0] // bsz
+
+        # [b, 1, s, s]
+        if attention_mask is None:
+            attention_mask = self.generate_patch_attention_mask(
+                s, bsz, q.device, cu_seqlens, self.flatten_batch, q.dtype
+            )
+        q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
+        # [b, 1, s]
+        if self.use_full_precision_softmax:
+            scale = self.head_size**-0.5
+            k_transposed = rearrange(k, "b h s d -> b h d s")
+            attn_weights = torch.matmul(q, k_transposed) * scale
+            del k, k_transposed
+            attn_weights = attn_weights + attention_mask
+            del attention_mask
+            # full-precision
+            attn_weights = nn.functional.softmax(
+                attn_weights, dim=-1, dtype=torch.float32
+            ).to(q.dtype)
+            attn_weights = nn.functional.dropout(
+                attn_weights, p=self.dropout, training=False
+            )
+            output = torch.matmul(attn_weights, v)
+            del attn_weights, v
+        else:
+            # SDPA
+            # [b, h, s, head_size]
+            output = F.scaled_dot_product_attention(
+                q, k, v, attention_mask, dropout_p=self.dropout
+            )
+
+        # [b, h, s, head_size] --> [b * s, h, head_size]
+        output = rearrange(output, "b h s d -> (b s) h d")
+
+        return output
+
+
+class VisionTritonAttention(nn.Module):
+    """
+    Triton-implemented attention without a causal mask
+    """
+
+    def __init__(
+        self,
+    ):
+        super().__init__()
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        _bsz: int,
+        cu_seqlens: Optional[torch.Tensor],
+        **kwargs,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            cu_seqlens: [b]
+        Returns:
+            [b * s, h, head_size]
+        """
+
+        # [b * s, head, head_size]
+        output = torch.empty_like(q)
+        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
+        max_seqlen = seq_lens.max().item()
+        context_attention_fwd(
+            q,
+            k,
+            v,
+            output,
+            cu_seqlens.cuda(),
+            seq_lens.cuda(),
+            max_seqlen,
+            is_causal=False,
+        )
+
+        return output
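
The core of the new `VisionSdpaAttention` backend is `generate_patch_attention_mask`, which turns `cu_seqlens` into an additive mask so patches only attend within their own sequence, then caches the result per `(s, bsz, flatten_batch, cu_seqlens)` key. A standalone sketch of the flattened-batch idea with made-up lengths (not the in-tree helper, just the same construction):

```python
# Sketch only: builds a block-diagonal additive mask the way the new backend does,
# for cu_seqlens = [0, 3, 5] (two sequences of 3 and 2 tokens flattened to s = 5).
import torch

cu_seqlens = torch.tensor([0, 3, 5])
s = int(cu_seqlens[-1])
dtype = torch.bfloat16

allowed = torch.zeros(1, 1, s, s, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
    start, end = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
    allowed[..., start:end, start:end] = True   # attend only within the same sequence

# Allowed positions become 0, disallowed ones a large negative bias,
# matching the (~mask).to(dtype) * torch.finfo(dtype).min conversion above.
additive_mask = (~allowed).to(dtype) * torch.finfo(dtype).min
print(additive_mask[0, 0])
```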
@@ -290,6 +290,13 @@ class Fp8LinearMethod(LinearMethodBase):
                 weight_scale, requires_grad=False
             )
             layer.input_scale = None
+        else:
+            layer.weight = torch.nn.Parameter(
+                layer.weight.data, requires_grad=False
+            )
+            layer.weight_scale_inv = torch.nn.Parameter(
+                layer.weight_scale_inv.data, requires_grad=False
+            )
             return
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
@@ -6,9 +6,15 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
+from vllm import _custom_ops as ops
 from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.utils import is_cuda_available
+
+_is_cuda_available = is_cuda_available()
+if _is_cuda_available:
+    from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -75,7 +81,9 @@ class RotaryEmbedding(CustomOp):
         self.dtype = dtype
 
         cache = self._compute_cos_sin_cache()
-        cache = cache.to(dtype)
+        # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
+        if not _is_cuda_available:
+            cache = cache.to(dtype)
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
@@ -141,17 +149,25 @@ class RotaryEmbedding(CustomOp):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        from vllm import _custom_ops as ops
-
-        self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
-        ops.rotary_embedding(
-            positions,
-            query,
-            key,
-            self.head_size,
-            self.cos_sin_cache,
-            self.is_neox_style,
-        )
+        if _is_cuda_available:
+            apply_rope_with_cos_sin_cache_inplace(
+                positions=positions,
+                query=query,
+                key=key,
+                head_size=self.head_size,
+                cos_sin_cache=self.cos_sin_cache,
+                is_neox=self.is_neox_style,
+            )
+        else:
+            self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
+            ops.rotary_embedding(
+                positions,
+                query,
+                key,
+                self.head_size,
+                self.cos_sin_cache,
+                self.is_neox_style,
+            )
         return query, key
 
     def forward_xpu(
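
Together, the two rope hunks keep `cos_sin_cache` in FP32 on CUDA and hand it to the fused `apply_rope_with_cos_sin_cache_inplace` kernel from `sgl_kernel`, only falling back to vLLM's `ops.rotary_embedding` (with a downcast cache) elsewhere. The numerical-stability NOTE is easy to see directly; a small torch-only illustration (not sglang code) of how much precision a half-precision cache gives up:

```python
# Sketch only: compares a cos/sin cache kept in fp32 against the same cache downcast to fp16.
import torch

inv_freq = 1.0 / (10000 ** (torch.arange(0, 64, 2, dtype=torch.float32) / 64))
positions = torch.arange(4096, dtype=torch.float32)
angles = torch.outer(positions, inv_freq)

cache_fp32 = torch.cat([angles.cos(), angles.sin()], dim=-1)
cache_fp16 = cache_fp32.to(torch.float16)

max_err = (cache_fp32 - cache_fp16.float()).abs().max()
print(f"max rounding error from storing the cache in fp16: {max_err:.2e}")
```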
@@ -72,9 +72,11 @@ class Sampler(nn.Module):
                 # NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
                 # https://github.com/flashinfer-ai/flashinfer/issues/708
                 # so we use the torch implementation.
+
+                # clamp to avoid -inf
                 logprobs = torch.log(
                     top_p_normalize_probs_torch(probs, sampling_info.top_ps)
-                )
+                ).clamp(min=torch.finfo(probs.dtype).min)
 
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
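
The clamp exists because top-p renormalization zeroes out the tail of the distribution, and `torch.log(0)` is `-inf`, which can turn later logprob arithmetic into NaN. A minimal reproduction of the failure mode and the fix:

```python
# Sketch only: demonstrates why the clamp is applied after top-p renormalization.
import torch

probs = torch.tensor([[0.7, 0.3, 0.0]])        # third token renormalized away
raw_logprobs = torch.log(probs)
print(raw_logprobs)                            # tensor([[-0.3567, -1.2040,    -inf]])

clamped = raw_logprobs.clamp(min=torch.finfo(probs.dtype).min)
print(clamped[0, 2])                           # finite: ~ -3.4028e+38 for float32
```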
@@ -109,9 +111,10 @@ class Sampler(nn.Module):
                 sampling_info.need_min_p_sampling,
             )
             if return_logprob:
+                # clamp to avoid -inf
                 logprobs = torch.log(
                     top_p_normalize_probs_torch(probs, sampling_info.top_ps)
-                )
+                ).clamp(min=torch.finfo(probs.dtype).min)
         else:
             raise ValueError(
                 f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
@@ -240,6 +240,7 @@ class MllamaImageProcessor(BaseImageProcessor):
 class MiniCPMVImageProcessor(BaseImageProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
+        self.IMAGE_TOKEN = "(<image>./</image>)"
 
     @staticmethod
     def _process_images_task(images, input_text):
@@ -271,7 +272,7 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
     async def process_images_async(
         self,
         image_data: List[Union[str, bytes]],
-        input_text,
+        input_ids,
         request_obj,
         max_req_input_len,
     ):
@@ -282,28 +283,49 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
             image_data = [image_data]
 
         image_hashes, image_sizes = [], []
-        raw_images = []
-        IMAGE_TOKEN = "(<image>./</image>)"
+        all_frames = []
 
-        # roughly calculate the max number of frames
-        # TODO: the process should be applied to all the visual inputs
+        # roughly calculate the max number of frames under the max_req_input_len limit
        def calculate_max_num_frames() -> int:
             # Model-specific
             NUM_TOKEN_PER_FRAME = 330
 
-            ret = (max_req_input_len - len(input_text)) // NUM_TOKEN_PER_FRAME
+            ret = (max_req_input_len - len(input_ids)) // NUM_TOKEN_PER_FRAME
             return min(ret, 100)
 
-        # if cuda OOM set a smaller number
         MAX_NUM_FRAMES = calculate_max_num_frames()
-        print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
 
-        def encode_video(video_path):
+        # print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
+
+        def get_estimated_frames_list():
+            """
+            estimate the total frame count from all visual input
+            """
+            # Before processing inputs
+            estimated_frames_list = []
+            for image in image_data:
+                if isinstance(image, str) and image.startswith("video:"):
+                    path = image[len("video:") :]
+                    # Estimate frames for the video
+                    vr = VideoReader(path, ctx=cpu(0))
+                    num_frames = len(vr)
+                else:
+                    # For images, each contributes one frame
+                    num_frames = 1
+                estimated_frames_list.append(num_frames)
+
+            return estimated_frames_list
+
+        estimated_frames_list = get_estimated_frames_list()
+        total_frame_count = sum(estimated_frames_list)
+        scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
+
+        def encode_video(video_path, frame_count_limit=None):
             if not os.path.exists(video_path):
                 logger.error(f"Video {video_path} does not exist")
                 return []
 
-            if MAX_NUM_FRAMES == 0:
+            if frame_count_limit == 0:
                 return []
 
             def uniform_sample(l, n):
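
The replacement logic budgets frames against the remaining prompt space instead of clipping each video independently: `MAX_NUM_FRAMES` is derived from `max_req_input_len` at roughly 330 tokens per frame (capped at 100), and each input is then allotted `max(1, int(estimated_frames * scaling_factor))` frames, as the next hunk shows. A worked example with made-up sizes:

```python
# Sketch only: the frame-budget arithmetic from these hunks, with illustrative input sizes.
max_req_input_len, prompt_len = 8192, 512
NUM_TOKEN_PER_FRAME = 330

MAX_NUM_FRAMES = min((max_req_input_len - prompt_len) // NUM_TOKEN_PER_FRAME, 100)  # 23

estimated_frames_list = [60, 1]                 # one video (~60 sampled frames) + one image
total_frame_count = sum(estimated_frames_list)  # 61
scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)  # ~0.377

frames_to_process = [max(1, int(n * scaling_factor)) for n in estimated_frames_list]
print(frames_to_process)                        # [22, 1] -> stays within the 23-frame budget
```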
@@ -314,45 +336,63 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
             vr = VideoReader(video_path, ctx=cpu(0))
             sample_fps = round(vr.get_avg_fps() / 1)  # FPS
             frame_idx = [i for i in range(0, len(vr), sample_fps)]
-            if len(frame_idx) > MAX_NUM_FRAMES:
-                frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
+            if frame_count_limit is not None and len(frame_idx) > frame_count_limit:
+                frame_idx = uniform_sample(frame_idx, frame_count_limit)
             frames = vr.get_batch(frame_idx).asnumpy()
             frames = [Image.fromarray(v.astype("uint8")) for v in frames]
             return frames
 
-        if isinstance(input_text, list):
-            assert len(input_text) and isinstance(input_text[0], int)
-            input_text = self._processor.tokenizer.decode(input_text)
-
+        if isinstance(input_ids, list):
+            assert len(input_ids) and isinstance(input_ids[0], int)
+            input_text = self._processor.tokenizer.decode(input_ids)
+        else:
+            input_text = input_ids
         # MiniCPMV requires each frame of video as a single image token
-        text_parts = input_text.split(IMAGE_TOKEN)
+        text_parts = input_text.split(self.IMAGE_TOKEN)
         new_text_parts = []
 
-        for image_index, image in enumerate(image_data):
-            try:
-                if isinstance(image, str) and image.startswith("video:"):
-                    path = image[len("video:") :]
-                    frames = encode_video(path)
-                else:
-                    raw_image, size = load_image(image)
-                    frames = [raw_image]
-                if len(frames) == 0:
-                    continue
-            except FileNotFoundError as e:
-                print(e)
-                return None
-
-            image_sizes += frames[0].size * len(frames)
-            image_hashes += [hash(image)] * len(frames)
-            raw_images += frames
+        # Process each input with allocated frames
+        for image_index, (image, estimated_frames) in enumerate(
+            zip(image_data, estimated_frames_list)
+        ):
+            if len(all_frames) >= MAX_NUM_FRAMES:
+                frames_to_process = 0
+            else:
+                frames_to_process = max(1, int(estimated_frames * scaling_factor))
+
+            if frames_to_process == 0:
+                frames = []
+            else:
+                try:
+                    if isinstance(image, str) and image.startswith("video:"):
+                        path = image[len("video:") :]
+                        frames = encode_video(path, frame_count_limit=frames_to_process)
+                    else:
+                        raw_image, _size = load_image(image)
+                        frames = [raw_image]
+                    if len(frames) == 0:
+                        continue
+                except FileNotFoundError as e:
+                    print(e)
+                    return None
+                image_sizes += frames[0].size * len(frames)
+                image_hashes += [hash(image)] * len(frames)
+                all_frames += frames
+
+                assert frames_to_process == len(frames)
+
             new_text_parts.append(text_parts[image_index])
-            new_text_parts.append(IMAGE_TOKEN * len(frames))
+
+            if frames_to_process != 0:
+                new_text_parts.append(self.IMAGE_TOKEN * len(frames))
 
         new_text_parts.append(text_parts[-1])
+
         input_text = "".join(new_text_parts)
-        if len(raw_images) == 0:
+
+        if len(all_frames) == 0:
             return None
-        res = await self._process_images(images=raw_images, input_text=input_text)
+        res = await self._process_images(images=all_frames, input_text=input_text)
         pixel_values = res["pixel_values"]
         tgt_sizes = res["tgt_sizes"]
         input_ids = res["input_ids"]
@@ -364,7 +404,6 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
         if tokenizer.slice_start_id:
             slice_start_id = [tokenizer.slice_start_id]
             slice_end_id = [tokenizer.slice_end_id]
-
         return {
             "input_ids": input_ids.flatten().tolist(),
             "pixel_values": pixel_values,