sglang 0.4.2__py3-none-any.whl → 0.4.2.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +6 -0
- sglang/srt/layers/attention/vision.py +243 -40
- sglang/srt/layers/quantization/fp8.py +7 -0
- sglang/srt/layers/rotary_embedding.py +28 -12
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/managers/image_processor.py +77 -38
- sglang/srt/managers/scheduler.py +17 -3
- sglang/srt/mem_cache/base_prefix_cache.py +4 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +30 -1
- sglang/srt/models/minicpmv.py +129 -76
- sglang/srt/models/mllama.py +16 -56
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_vl.py +18 -8
- sglang/srt/server_args.py +6 -0
- sglang/srt/utils.py +0 -2
- sglang/utils.py +42 -0
- sglang/version.py +1 -1
- {sglang-0.4.2.dist-info → sglang-0.4.2.post1.dist-info}/METADATA +3 -3
- {sglang-0.4.2.dist-info → sglang-0.4.2.post1.dist-info}/RECORD +23 -23
- {sglang-0.4.2.dist-info → sglang-0.4.2.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.2.dist-info → sglang-0.4.2.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.2.dist-info → sglang-0.4.2.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/triton_ops/prefill_attention.py
CHANGED
@@ -166,6 +166,12 @@ def _fwd_kernel(
 def context_attention_fwd(
     q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
 ):
+    """
+    q, k, v: [b * s, head, head_dim]
+    b_start_loc: [b]
+    b_seq_len: [b]
+    out: [b * s, head, head_dim]
+    """
     if is_cuda_available and CUDA_CAPABILITY[0] > 8:
         BLOCK = 128
     else:
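The added docstring pins down the packed ("varlen") layout the kernel expects: all sequences are concatenated along the first axis, and b_start_loc / b_seq_len describe where each sequence's rows live. A small illustrative sketch of how those buffers relate (sizes invented for the example, not taken from sglang):

import torch

head, head_dim = 8, 64
b_seq_len = torch.tensor([3, 5, 2])                        # [b] tokens per sequence
b_start_loc = torch.cumsum(b_seq_len, dim=0) - b_seq_len   # [b] exclusive prefix sum
total_tokens = int(b_seq_len.sum())                        # the "b * s" of the docstring

q = torch.randn(total_tokens, head, head_dim)              # [b * s, head, head_dim]
k, v = torch.randn_like(q), torch.randn_like(q)
o = torch.empty_like(q)                                    # output buffer passed as `o`

# Sequence i occupies rows b_start_loc[i] : b_start_loc[i] + b_seq_len[i].
print(b_start_loc.tolist())                                # [0, 3, 8]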
sglang/srt/layers/attention/vision.py
CHANGED
@@ -4,6 +4,7 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from einops import rearrange, repeat
 
 from sglang.srt.distributed import parallel_state
@@ -63,7 +64,20 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.T
 
 
 class VisionAttention(nn.Module):
-    """Multi-headed attention without any cache, mostly used for ViT."""
+    r"""
+    Multi-headed attention without any cache, mostly used for ViT.
+
+
+    Args:
+        use_qkv_parallel (bool, optional): If True, use QKV-parallel attention.
+        use_context_forward (bool, default to True):
+            if ``True``, a flash_attn style attention will be applied
+            Otherwise, a full-sequence attention will be applied.
+        use_full_precision_softmax (bool, default to False):
+            if ``True``, the softmax will be performed in full-precision
+            Otherwise, it will be performed in half-precision
+
+    """
 
     def __init__(
         self,
@@ -72,25 +86,39 @@ class VisionAttention(nn.Module):
         projection_size: int,
         use_qkv_parallel: bool,
         quant_config: Optional[QuantizationConfig] = None,
+        dropout: float = 0.0,
+        use_context_forward: bool = True,
+        use_full_precision_softmax: bool = False,
+        flatten_batch: bool = False,
         prefix: str = "",
     ):
         super().__init__()
+        self.use_context_forward = use_context_forward
         world_size = parallel_state.get_tensor_model_parallel_world_size()
-
+        self.dropout = dropout
+        self.head_size = embed_dim // num_heads
         self.hidden_size_per_attention_head = dist_utils.divide(
             projection_size, num_heads
         )
         self.num_attention_heads_per_partition = dist_utils.divide(
             num_heads, world_size
         )
-
-
+
+        if self.use_context_forward:
+            self.qkv_backend = VisionTritonAttention()
+        else:
+            self.qkv_backend = VisionSdpaAttention(
+                head_size=self.head_size,
+                dropout=dropout,
+                flatten_batch=flatten_batch,
+                use_full_precision_softmax=use_full_precision_softmax,
+            )
+
         self.use_qkv_parallel = use_qkv_parallel
         if use_qkv_parallel:
-            self.head_dim = embed_dim // num_heads
             self.qkv_proj = QKVParallelLinear(
                 hidden_size=embed_dim,
-                head_size=self.head_dim,
+                head_size=self.head_size,
                 total_num_heads=num_heads,
                 quant_config=quant_config,
                 prefix=f"{prefix}.qkv_proj",
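For orientation, the constructor above now fixes the attention backend at construction time and the forward pass simply delegates to it. A toy sketch of the same dispatch pattern (standalone stand-in classes, not sglang's real implementations):

import torch.nn as nn
import torch.nn.functional as F

class ToySdpaBackend(nn.Module):
    """Stand-in for VisionSdpaAttention: full-sequence attention with an optional mask."""
    def forward(self, q, k, v, mask=None):
        # q, k, v: [b, h, s, d]
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

class ToyVarlenBackend(nn.Module):
    """Stand-in for VisionTritonAttention (a flash_attn-style varlen kernel)."""
    def forward(self, q, k, v, mask=None):
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

class ToyVisionAttention(nn.Module):
    def __init__(self, use_context_forward: bool = True):
        super().__init__()
        # Same idea as the diff: pick the backend once, then delegate in forward().
        self.qkv_backend = ToyVarlenBackend() if use_context_forward else ToySdpaBackend()

    def forward(self, q, k, v, mask=None):
        return self.qkv_backend(q, k, v, mask)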
@@ -114,12 +142,15 @@ class VisionAttention(nn.Module):
         x: torch.Tensor,
         cu_seqlens: Optional[torch.Tensor] = None,
         rotary_pos_emb: torch.Tensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        r"""
+        Args:
+            x: [b, s, embed_dim]
+            cu_seqlens: [b]
+        Returns:
+            [s, b, num_heads * head]
+        """
         """
-        Input shape: [b, s, embed_dim]
-        Output shape: [s, b, num_heads * head_size]
-        """
-
         bsz, s, _ = x.shape
         if self.use_qkv_parallel:
             # [b, s, embed_dim] --> [b, s, embed_dim]
@@ -136,19 +167,19 @@ class VisionAttention(nn.Module):
         else:
             # [b, s, embed_dim] --> [s, b, embed_dim]
             x = rearrange(x, "b s ... -> s b ...")
-            # [s, b, embed_dim] --> [s, b, head * 3 * head_dim]
+            # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
             qkv, _ = self.qkv_proj(x)
-            # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
+            # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
             new_x_shape = qkv.size()[:-1] + (
                 self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head,
             )
             qkv = qkv.view(*new_x_shape)
 
-            # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
+            # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
             q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
 
-            # [s, b, head, head_dim] --> [b, s, head, head_dim]
+            # [s, b, head, head_size] --> [b, s, head, head_size]
             q, k, v = [
                 rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
             ]
@@ -160,45 +191,217 @@ class VisionAttention(nn.Module):
         if self.use_qkv_parallel:
             pass
         else:
-            # [b, s, head, head_dim] --> [b * s, head, head_dim]
+            # [b, s, head, head_size] --> [b * s, head, head_size]
             q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
 
-
-        output = torch.empty_like(q)
-
-        seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cuda()
-        max_seqlen = seq_lens.max().item()
-
-        context_attention_fwd(
-            q,
-            k,
-            v,
-            output,
-            cu_seqlens.cuda(),
-            seq_lens,
-            max_seqlen,
-            is_causal=False,
-        )
+        output = self.qkv_backend.forward(q, k, v, bsz, cu_seqlens, attention_mask)
 
         if self.use_qkv_parallel:
-
-            # [b * s, head, head_dim] --> [b, s, head * head_dim]
+            # [b * s, h, head_size] --> [b, s, h * head_size]
             output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
 
-            # [b, s, head * head_dim] --> [b, s, head * head_dim]
+            # [b, s, h * head_size] --> [b, s, h * head_size]
             output, _ = self.proj(output)
         else:
-            # [b * s, head, head_dim] --> [b, s, head, head_dim]
-            context_layer = rearrange(output, "(b s) ... -> b s ...", b=bsz)
-
-            # [s, b, num_heads * head_size]
+            # [b * s, h, head_size] --> [s, b, h * head_size]
             context_layer = rearrange(
-                context_layer, "b s h d -> s b (h d)"
+                output, "(b s) h d -> s b (h d)", b=bsz, s=s
             ).contiguous()
 
-            # [s, b, num_heads * head_size] --> [s, b, num_heads * head_size]
+            # [s, b, h * head_size] --> [s, b, h * head_size]
             output, _ = self.proj(context_layer)
 
+        # [s, b, h * head_size] --> [b, s, h * head_size]
         output = output.view(bsz, s, -1)
 
         return output
+
+
+class VisionSdpaAttention(nn.Module):
+    r"""
+    Scaled Dot Product Attention inner product
+
+    """
+
+    # TODO: Should it be released after used?
+    _mask_cache = {}
+
+    def __init__(
+        self,
+        head_size: int,
+        dropout: float = 0.0,
+        flatten_batch: bool = False,
+        use_full_precision_softmax: bool = False,
+    ):
+        super().__init__()
+        self.head_size = head_size
+        self.flatten_batch = flatten_batch
+        self.use_full_precision_softmax = use_full_precision_softmax
+        self.dropout = dropout
+
+    def generate_patch_attention_mask(
+        self,
+        s: int,
+        bsz: int,
+        device,
+        cu_seqlens: Optional[torch.Tensor],
+        flatten_batch: bool = False,
+        dtype=torch.bfloat16,
+    ) -> torch.Tensor:
+        r"""
+        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+
+        When `flatten_batch` is True:
+            - All sequences in the batch are flattened into a single dimension
+            - `s` represents the total number of tokens across all sequences in the batch
+            - Returns a unified mask of shape `(1, 1, s, s)`
+
+        When `flatten_batch` is False:
+            - Each sequence has its own attention mask
+            - `s` represents the maximum sequence length in the batch
+            - Returns separate masks of shape `(b, 1, s, s)`
+
+        Args:
+            flatten_batch: (bool):
+                If True, treats all sequences in the batch as a single flattened sequence
+                If False, generates separate masks for each sequence
+
+        Returns:
+            Tensor of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+        """
+
+        cache_key = (s, bsz, flatten_batch, tuple(cu_seqlens.cpu().tolist()))
+
+        if cache_key in VisionSdpaAttention._mask_cache:
+            cached_mask = VisionSdpaAttention._mask_cache[cache_key]
+            # print(f"cache hit for key: {cache_key}")
+            return cached_mask.to(device=device, dtype=dtype)
+
+        if cu_seqlens is None:
+            raise ValueError("Internal Error: cu_seqlens cannot be None")
+
+        if flatten_batch:
+            mask = torch.zeros([1, s, s], device=device, dtype=torch.bool)
+            for i in range(1, len(cu_seqlens)):
+                start = cu_seqlens[i - 1]
+                end = cu_seqlens[i]
+                mask[
+                    ...,
+                    start:end,
+                    start:end,
+                ] = True
+        else:
+            # [1, 1, 1, s]
+            row_indices = torch.arange(s, device=device).view(1, 1, 1, s)
+            # [1, 1, s, 1]
+            col_indices = torch.arange(s, device=device).view(1, 1, s, 1)
+            # [b, 1, 1, 1]
+            seq_lens = (
+                (cu_seqlens[1:] - cu_seqlens[:-1]).to(device=device).view(-1, 1, 1, 1)
+            )
+
+            mask = (row_indices < seq_lens) & (col_indices < seq_lens)
+
+        # Convert to attention mask format (False -> 0, True -> -inf)
+        mask = (~mask).to(dtype) * torch.finfo(dtype).min
+
+        VisionSdpaAttention._mask_cache[cache_key] = mask
+
+        return mask
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        bsz: int,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            cu_seqlens: [b]
+        Returns:
+            [b * s, h, head_size]
+        """
+
+        s = q.shape[0] // bsz
+
+        # [b, 1, s, s]
+        if attention_mask is None:
+            attention_mask = self.generate_patch_attention_mask(
+                s, bsz, q.device, cu_seqlens, self.flatten_batch, q.dtype
+            )
+        q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
+        # [b, 1, s]
+        if self.use_full_precision_softmax:
+            scale = self.head_size**-0.5
+            k_transposed = rearrange(k, "b h s d -> b h d s")
+            attn_weights = torch.matmul(q, k_transposed) * scale
+            del k, k_transposed
+            attn_weights = attn_weights + attention_mask
+            del attention_mask
+            # full-precision
+            attn_weights = nn.functional.softmax(
+                attn_weights, dim=-1, dtype=torch.float32
+            ).to(q.dtype)
+            attn_weights = nn.functional.dropout(
+                attn_weights, p=self.dropout, training=False
+            )
+            output = torch.matmul(attn_weights, v)
+            del attn_weights, v
+        else:
+            # SDPA
+            # [b, h, s, head_size]
+            output = F.scaled_dot_product_attention(
+                q, k, v, attention_mask, dropout_p=self.dropout
+            )
+
+        # [b, h, s, head_size] --> [b * s, h, head_size]
+        output = rearrange(output, "b h s d -> (b s) h d")
+
+        return output
+
+
+class VisionTritonAttention(nn.Module):
+    """
+    Triton-implemented attention without a causal mask
+    """
+
+    def __init__(
+        self,
+    ):
+        super().__init__()
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        _bsz: int,
+        cu_seqlens: Optional[torch.Tensor],
+        **kwargs,
+    ) -> torch.Tensor:
+        r"""
+        Args:
+            cu_seqlens: [b]
+        Returns:
+            [b * s, h, head_size]
+        """
+
+        # [b * s, head, head_size]
+        output = torch.empty_like(q)
+        seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
+        max_seqlen = seq_lens.max().item()
+        context_attention_fwd(
+            q,
+            k,
+            v,
+            output,
+            cu_seqlens.cuda(),
+            seq_lens.cuda(),
+            max_seqlen,
+            is_causal=False,
+        )
+
+        return output
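The new VisionSdpaAttention.generate_patch_attention_mask builds a block-diagonal, non-causal mask from cu_seqlens so tokens only attend within their own image/sequence. A standalone sketch of the same idea (independent of sglang; the function name and shapes here are illustrative):

import torch

def block_diagonal_mask(cu_seqlens: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    """Additive mask over the flattened token axis, mirroring the flatten_batch path.

    cu_seqlens: [b + 1] cumulative sequence lengths, e.g. [0, 3, 7].
    Returns [1, 1, s, s]: 0 where attention is allowed, torch.finfo(dtype).min elsewhere.
    """
    s = int(cu_seqlens[-1])
    allowed = torch.zeros(1, 1, s, s, dtype=torch.bool)
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        allowed[..., start:end, start:end] = True
    return (~allowed).to(dtype) * torch.finfo(dtype).min

mask = block_diagonal_mask(torch.tensor([0, 3, 7]))
print(mask.shape)         # torch.Size([1, 1, 7, 7])
print((mask == 0).sum())  # tensor(25): 3*3 + 4*4 allowed positions

The diff additionally memoizes the mask per (s, bsz, flatten_batch, cu_seqlens) key in _mask_cache, since ViT batches tend to repeat the same shapes.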
sglang/srt/layers/quantization/fp8.py
CHANGED
@@ -290,6 +290,13 @@ class Fp8LinearMethod(LinearMethodBase):
                     weight_scale, requires_grad=False
                 )
                 layer.input_scale = None
+            else:
+                layer.weight = torch.nn.Parameter(
+                    layer.weight.data, requires_grad=False
+                )
+                layer.weight_scale_inv = torch.nn.Parameter(
+                    layer.weight_scale_inv.data, requires_grad=False
+                )
             return
         layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
         # If checkpoint not serialized fp8, quantize the weights.
sglang/srt/layers/rotary_embedding.py
CHANGED
@@ -6,9 +6,15 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
+from vllm import _custom_ops as ops
 from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.layers.custom_op_util import register_custom_op
+from sglang.srt.utils import is_cuda_available
+
+_is_cuda_available = is_cuda_available()
+if _is_cuda_available:
+    from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -75,7 +81,9 @@ class RotaryEmbedding(CustomOp):
         self.dtype = dtype
 
         cache = self._compute_cos_sin_cache()
-        cache = cache.to(dtype)
+        # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
+        if not _is_cuda_available:
+            cache = cache.to(dtype)
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
@@ -141,17 +149,25 @@ class RotaryEmbedding(CustomOp):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        [11 removed lines; their content is not preserved in this diff view]
+        if _is_cuda_available:
+            apply_rope_with_cos_sin_cache_inplace(
+                positions=positions,
+                query=query,
+                key=key,
+                head_size=self.head_size,
+                cos_sin_cache=self.cos_sin_cache,
+                is_neox=self.is_neox_style,
+            )
+        else:
+            self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
+            ops.rotary_embedding(
+                positions,
+                query,
+                key,
+                self.head_size,
+                self.cos_sin_cache,
+                self.is_neox_style,
+            )
         return query, key
 
     def forward_xpu(
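The rewritten forward_cuda dispatches to sgl_kernel's fused rope when CUDA is available and otherwise falls back to the vLLM op, with the cos/sin cache kept in FP32 on the CUDA path (per the NOTE above). As a reference for what either path computes, here is an unfused NeoX-style application in plain PyTorch; this is a sketch of the math, not sglang's kernel:

import torch

def apply_neox_rope(x: torch.Tensor, positions: torch.Tensor, base: float = 10000.0):
    """x: [num_tokens, num_heads, head_dim]; positions: [num_tokens]."""
    head_dim = x.shape[-1]
    # Angles are computed in float32 to avoid precision loss at large positions,
    # which is the motivation for keeping the cache in FP32.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    angles = positions.float()[:, None] * inv_freq[None, :]        # [num_tokens, head_dim // 2]
    cos, sin = angles.cos()[:, None, :], angles.sin()[:, None, :]  # broadcast over heads
    x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2 :]      # NeoX: rotate half-pairs
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).to(x.dtype)

q = torch.randn(10, 8, 64, dtype=torch.bfloat16)
q_rot = apply_neox_rope(q, torch.arange(10))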
sglang/srt/layers/sampler.py
CHANGED
@@ -72,9 +72,11 @@ class Sampler(nn.Module):
                     # NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
                     # https://github.com/flashinfer-ai/flashinfer/issues/708
                     # so we use the torch implementation.
+
+                    # clamp to avoid -inf
                     logprobs = torch.log(
                         top_p_normalize_probs_torch(probs, sampling_info.top_ps)
-                    )
+                    ).clamp(min=torch.finfo(probs.dtype).min)
 
                 max_top_k_round, batch_size = 32, probs.shape[0]
                 uniform_samples = torch.rand(
@@ -109,9 +111,10 @@ class Sampler(nn.Module):
                     sampling_info.need_min_p_sampling,
                 )
                 if return_logprob:
+                    # clamp to avoid -inf
                     logprobs = torch.log(
                         top_p_normalize_probs_torch(probs, sampling_info.top_ps)
-                    )
+                    ).clamp(min=torch.finfo(probs.dtype).min)
             else:
                 raise ValueError(
                     f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
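The .clamp(...) matters because tokens zeroed out by top-p renormalization have probability exactly 0, so torch.log produces -inf, which then propagates through any later logprob arithmetic. A minimal demonstration of the failure and the fix (illustrative, outside sglang):

import torch

probs = torch.tensor([0.6, 0.4, 0.0])   # third token removed by top-p renormalization
print(torch.log(probs))                  # tensor([-0.5108, -0.9163, -inf])

safe = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
print(safe[-1])                          # the dtype's finite minimum instead of -inf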
sglang/srt/managers/image_processor.py
CHANGED
@@ -240,6 +240,7 @@ class MllamaImageProcessor(BaseImageProcessor):
 class MiniCPMVImageProcessor(BaseImageProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
+        self.IMAGE_TOKEN = "(<image>./</image>)"
 
     @staticmethod
     def _process_images_task(images, input_text):
@@ -271,7 +272,7 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
     async def process_images_async(
         self,
         image_data: List[Union[str, bytes]],
-        input_text,
+        input_ids,
         request_obj,
         max_req_input_len,
     ):
@@ -282,28 +283,49 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
             image_data = [image_data]
 
         image_hashes, image_sizes = [], []
-
-        IMAGE_TOKEN = "(<image>./</image>)"
+        all_frames = []
 
-        # roughly calculate the max number of frames
-        # TODO: the process should be applied to all the visual inputs
+        # roughly calculate the max number of frames under the max_req_input_len limit
        def calculate_max_num_frames() -> int:
             # Model-specific
             NUM_TOKEN_PER_FRAME = 330
 
-            ret = (max_req_input_len - len(input_text)) // NUM_TOKEN_PER_FRAME
+            ret = (max_req_input_len - len(input_ids)) // NUM_TOKEN_PER_FRAME
             return min(ret, 100)
 
-        # if cuda OOM set a smaller number
         MAX_NUM_FRAMES = calculate_max_num_frames()
-        print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
 
-        def encode_video(video_path):
+        # print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}")
+
+        def get_estimated_frames_list():
+            """
+            estimate the total frame count from all visual input
+            """
+            # Before processing inputs
+            estimated_frames_list = []
+            for image in image_data:
+                if isinstance(image, str) and image.startswith("video:"):
+                    path = image[len("video:") :]
+                    # Estimate frames for the video
+                    vr = VideoReader(path, ctx=cpu(0))
+                    num_frames = len(vr)
+                else:
+                    # For images, each contributes one frame
+                    num_frames = 1
+                estimated_frames_list.append(num_frames)
+
+            return estimated_frames_list
+
+        estimated_frames_list = get_estimated_frames_list()
+        total_frame_count = sum(estimated_frames_list)
+        scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count)
+
+        def encode_video(video_path, frame_count_limit=None):
             if not os.path.exists(video_path):
                 logger.error(f"Video {video_path} does not exist")
                 return []
 
-            if MAX_NUM_FRAMES == 0:
+            if frame_count_limit == 0:
                 return []
 
             def uniform_sample(l, n):
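The new budgeting logic above caps the total number of frames by the remaining token budget (330 tokens per frame for this model, at most 100 frames) and then scales each input's estimated frame count proportionally. A small sketch of the arithmetic with made-up numbers:

# Hypothetical request, mirroring the budget math added in the diff.
NUM_TOKEN_PER_FRAME = 330

max_req_input_len = 8192
prompt_len = 1200                                   # len(input_ids)
max_num_frames = min((max_req_input_len - prompt_len) // NUM_TOKEN_PER_FRAME, 100)
print(max_num_frames)                               # 21

estimated_frames_list = [80, 2]                     # a long video plus a short one
scaling_factor = min(1.0, max_num_frames / sum(estimated_frames_list))

frames_to_process = [max(1, int(n * scaling_factor)) for n in estimated_frames_list]
print(frames_to_process)                            # [20, 1] -> within the 21-frame cap

Because of the max(1, ...) floor the per-input allocations can overshoot slightly; the processing loop later in the diff guards against that with its len(all_frames) >= MAX_NUM_FRAMES check.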
@@ -314,45 +336,63 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
             vr = VideoReader(video_path, ctx=cpu(0))
             sample_fps = round(vr.get_avg_fps() / 1)  # FPS
             frame_idx = [i for i in range(0, len(vr), sample_fps)]
-            if len(frame_idx) > MAX_NUM_FRAMES:
-                frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES)
+            if frame_count_limit is not None and len(frame_idx) > frame_count_limit:
+                frame_idx = uniform_sample(frame_idx, frame_count_limit)
             frames = vr.get_batch(frame_idx).asnumpy()
             frames = [Image.fromarray(v.astype("uint8")) for v in frames]
             return frames
 
-        if isinstance(input_text, list):
-            assert len(input_text) and isinstance(input_text[0], int)
-            input_text = self._processor.tokenizer.decode(input_text)
-
+        if isinstance(input_ids, list):
+            assert len(input_ids) and isinstance(input_ids[0], int)
+            input_text = self._processor.tokenizer.decode(input_ids)
+        else:
+            input_text = input_ids
         # MiniCPMV requires each frame of video as a single image token
-        text_parts = input_text.split(IMAGE_TOKEN)
+        text_parts = input_text.split(self.IMAGE_TOKEN)
         new_text_parts = []
 
-        [17 removed lines; their content is not preserved in this diff view]
+        # Process each input with allocated frames
+        for image_index, (image, estimated_frames) in enumerate(
+            zip(image_data, estimated_frames_list)
+        ):
+            if len(all_frames) >= MAX_NUM_FRAMES:
+                frames_to_process = 0
+            else:
+                frames_to_process = max(1, int(estimated_frames * scaling_factor))
+
+            if frames_to_process == 0:
+                frames = []
+            else:
+                try:
+                    if isinstance(image, str) and image.startswith("video:"):
+                        path = image[len("video:") :]
+                        frames = encode_video(path, frame_count_limit=frames_to_process)
+                    else:
+                        raw_image, _size = load_image(image)
+                        frames = [raw_image]
+                    if len(frames) == 0:
+                        continue
+                except FileNotFoundError as e:
+                    print(e)
+                    return None
+            image_sizes += frames[0].size * len(frames)
+            image_hashes += [hash(image)] * len(frames)
+            all_frames += frames
+
+            assert frames_to_process == len(frames)
+
             new_text_parts.append(text_parts[image_index])
-            new_text_parts.append(IMAGE_TOKEN * len(frames))
+
+            if frames_to_process != 0:
+                new_text_parts.append(self.IMAGE_TOKEN * len(frames))
 
         new_text_parts.append(text_parts[-1])
+
         input_text = "".join(new_text_parts)
-        if len(images) == 0:
+
+        if len(all_frames) == 0:
             return None
-        res = await self._process_images(images=images, input_text=input_text)
+        res = await self._process_images(images=all_frames, input_text=input_text)
         pixel_values = res["pixel_values"]
         tgt_sizes = res["tgt_sizes"]
         input_ids = res["input_ids"]
@@ -364,7 +404,6 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
         if tokenizer.slice_start_id:
             slice_start_id = [tokenizer.slice_start_id]
             slice_end_id = [tokenizer.slice_end_id]
-
         return {
             "input_ids": input_ids.flatten().tolist(),
             "pixel_values": pixel_values,
|