fastembed_bio-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastembed/__init__.py +24 -0
- fastembed/bio/__init__.py +3 -0
- fastembed/bio/protein_embedding.py +456 -0
- fastembed/common/__init__.py +3 -0
- fastembed/common/model_description.py +52 -0
- fastembed/common/model_management.py +471 -0
- fastembed/common/onnx_model.py +188 -0
- fastembed/common/preprocessor_utils.py +84 -0
- fastembed/common/types.py +27 -0
- fastembed/common/utils.py +69 -0
- fastembed/embedding.py +24 -0
- fastembed/image/__init__.py +3 -0
- fastembed/image/image_embedding.py +135 -0
- fastembed/image/image_embedding_base.py +55 -0
- fastembed/image/onnx_embedding.py +217 -0
- fastembed/image/onnx_image_model.py +156 -0
- fastembed/image/transform/functional.py +221 -0
- fastembed/image/transform/operators.py +499 -0
- fastembed/late_interaction/__init__.py +5 -0
- fastembed/late_interaction/colbert.py +301 -0
- fastembed/late_interaction/jina_colbert.py +58 -0
- fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
- fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
- fastembed/late_interaction/token_embeddings.py +83 -0
- fastembed/late_interaction_multimodal/__init__.py +5 -0
- fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
- fastembed/late_interaction_multimodal/colpali.py +327 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
- fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
- fastembed/parallel_processor.py +253 -0
- fastembed/postprocess/__init__.py +3 -0
- fastembed/postprocess/muvera.py +362 -0
- fastembed/py.typed +1 -0
- fastembed/rerank/cross_encoder/__init__.py +3 -0
- fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
- fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
- fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
- fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
- fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
- fastembed/sparse/__init__.py +4 -0
- fastembed/sparse/bm25.py +359 -0
- fastembed/sparse/bm42.py +369 -0
- fastembed/sparse/minicoil.py +372 -0
- fastembed/sparse/sparse_embedding_base.py +90 -0
- fastembed/sparse/sparse_text_embedding.py +143 -0
- fastembed/sparse/splade_pp.py +196 -0
- fastembed/sparse/utils/minicoil_encoder.py +146 -0
- fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
- fastembed/sparse/utils/tokenizer.py +120 -0
- fastembed/sparse/utils/vocab_resolver.py +202 -0
- fastembed/text/__init__.py +3 -0
- fastembed/text/clip_embedding.py +56 -0
- fastembed/text/custom_text_embedding.py +97 -0
- fastembed/text/multitask_embedding.py +109 -0
- fastembed/text/onnx_embedding.py +353 -0
- fastembed/text/onnx_text_model.py +180 -0
- fastembed/text/pooled_embedding.py +136 -0
- fastembed/text/pooled_normalized_embedding.py +164 -0
- fastembed/text/text_embedding.py +228 -0
- fastembed/text/text_embedding_base.py +75 -0
- fastembed_bio-0.1.0.dist-info/METADATA +339 -0
- fastembed_bio-0.1.0.dist-info/RECORD +66 -0
- fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
- fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
- fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
fastembed/image/transform/operators.py
@@ -0,0 +1,499 @@
+from typing import Any
+import math
+
+from PIL import Image
+
+from fastembed.common.types import NumpyArray
+from fastembed.image.transform.functional import (
+    center_crop,
+    convert_to_rgb,
+    crop_ndarray,
+    normalize,
+    pil2ndarray,
+    rescale,
+    resize,
+    resize_longest_edge,
+    resize_ndarray,
+    pad2square,
+)
+
+
+class Transform:
+    def __call__(self, images: list[Any]) -> list[Image.Image] | list[NumpyArray]:
+        raise NotImplementedError("Subclasses must implement this method")
+
+
+class ConvertToRGB(Transform):
+    def __call__(self, images: list[Image.Image]) -> list[Image.Image]:
+        return [convert_to_rgb(image=image) for image in images]
+
+
+class CenterCrop(Transform):
+    def __init__(self, size: tuple[int, int]):
+        self.size = size
+
+    def __call__(self, images: list[Image.Image]) -> list[NumpyArray]:
+        return [center_crop(image=image, size=self.size) for image in images]
+
+
+class Normalize(Transform):
+    def __init__(self, mean: float | list[float], std: float | list[float]):
+        self.mean = mean
+        self.std = std
+
+    def __call__(  # type: ignore[override]
+        self, images: list[NumpyArray] | list[list[NumpyArray]]
+    ) -> list[NumpyArray] | list[list[NumpyArray]]:
+        if images and isinstance(images[0], list):
+            # Nested structure from ImageSplitter
+            return [
+                [normalize(image, mean=self.mean, std=self.std) for image in img_patches]  # type: ignore[arg-type]
+                for img_patches in images
+            ]
+        else:
+            # Flat structure (backward compatibility)
+            return [normalize(image, mean=self.mean, std=self.std) for image in images]  # type: ignore[arg-type]
+
+
+class Resize(Transform):
+    def __init__(
+        self,
+        size: int | tuple[int, int],
+        resample: Image.Resampling = Image.Resampling.BICUBIC,
+    ):
+        self.size = size
+        self.resample = resample
+
+    def __call__(self, images: list[Image.Image]) -> list[Image.Image]:
+        return [resize(image, size=self.size, resample=self.resample) for image in images]
+
+
+class Rescale(Transform):
+    def __init__(self, scale: float = 1 / 255):
+        self.scale = scale
+
+    def __call__(  # type: ignore[override]
+        self, images: list[NumpyArray] | list[list[NumpyArray]]
+    ) -> list[NumpyArray] | list[list[NumpyArray]]:
+        if images and isinstance(images[0], list):
+            # Nested structure from ImageSplitter
+            return [
+                [rescale(image, scale=self.scale) for image in img_patches]  # type: ignore[arg-type]
+                for img_patches in images
+            ]
+        else:
+            # Flat structure (backward compatibility)
+            return [rescale(image, scale=self.scale) for image in images]  # type: ignore[arg-type]
+
+
+class PILtoNDarray(Transform):
+    def __call__(self, images: list[Image.Image | NumpyArray]) -> list[NumpyArray]:
+        return [pil2ndarray(image) for image in images]
+
+
+class PadtoSquare(Transform):
+    def __init__(
+        self,
+        size: int,
+        fill_color: str | int | tuple[int, ...],
+    ):
+        self.size = size
+        self.fill_color = fill_color
+
+    def __call__(self, images: list[Image.Image]) -> list[Image.Image]:
+        return [
+            pad2square(image=image, size=self.size, fill_color=self.fill_color) for image in images
+        ]
+
+
+class ResizeLongestEdge(Transform):
+    """Resize images so the longest edge equals target size, preserving aspect ratio."""
+
+    def __init__(
+        self,
+        size: int,
+        resample: Image.Resampling = Image.Resampling.LANCZOS,
+    ):
+        self.size = size
+        self.resample = resample
+
+    def __call__(self, images: list[Image.Image]) -> list[Image.Image]:
+        return [resize_longest_edge(image, self.size, self.resample) for image in images]
+
+
+class ResizeForVisionEncoder(Transform):
+    """
+    Resize both dimensions to be multiples of vision_encoder_max_size.
+    Preserves aspect ratio approximately.
+    Works on numpy arrays in (C, H, W) format.
+    """
+
+    def __init__(
+        self,
+        max_size: int,
+        resample: Image.Resampling = Image.Resampling.LANCZOS,
+    ):
+        self.max_size = max_size
+        self.resample = resample
+
+    def __call__(self, images: list[NumpyArray]) -> list[NumpyArray]:
+        result = []
+        for image in images:
+            # Assume (C, H, W) format
+            _, height, width = image.shape
+
+            aspect_ratio = width / height
+
+            if width >= height:
+                # Calculate new width as multiple of max_size
+                new_width = math.ceil(width / self.max_size) * self.max_size
+                new_height = int(new_width / aspect_ratio)
+                new_height = math.ceil(new_height / self.max_size) * self.max_size
+            else:
+                # Calculate new height as multiple of max_size
+                new_height = math.ceil(height / self.max_size) * self.max_size
+                new_width = int(new_height * aspect_ratio)
+                new_width = math.ceil(new_width / self.max_size) * self.max_size
+
+            # Resize using the ndarray resize function
+            resized = resize_ndarray(
+                image,
+                size=(new_width, new_height),  # PIL expects (width, height)
+                resample=self.resample,
+                channel_first=True,
+            )
+            result.append(resized)
+
+        return result
+
+
+class ImageSplitter(Transform):
+    """
+    Split images into grid of patches plus a global view.
+
+    If image dimensions exceed max_size:
+    - Divide into ceil(H/max_size) x ceil(W/max_size) patches
+    - Each patch is cropped from the image
+    - Add a global view (original resized to max_size x max_size)
+
+    If image is smaller than max_size:
+    - Return single image unchanged
+
+    Works on numpy arrays in (C, H, W) format.
+    """
+
+    def __init__(
+        self,
+        max_size: int,
+        resample: Image.Resampling = Image.Resampling.LANCZOS,
+    ):
+        self.max_size = max_size
+        self.resample = resample
+
+    def __call__(self, images: list[NumpyArray]) -> list[list[NumpyArray]]:  # type: ignore[override]
+        result = []
+
+        for image in images:
+            # Assume (C, H, W) format
+            _, height, width = image.shape
+            max_height = max_width = self.max_size
+
+            frames = []
+
+            if height > max_height or width > max_width:
+                # Calculate the number of splits needed
+                num_splits_h = math.ceil(height / max_height)
+                num_splits_w = math.ceil(width / max_width)
+
+                # Calculate optimal patch dimensions
+                optimal_height = math.ceil(height / num_splits_h)
+                optimal_width = math.ceil(width / num_splits_w)
+
+                # Generate patches in grid order (row by row)
+                for r in range(num_splits_h):
+                    for c in range(num_splits_w):
+                        # Calculate crop coordinates
+                        start_x = c * optimal_width
+                        start_y = r * optimal_height
+                        end_x = min(start_x + optimal_width, width)
+                        end_y = min(start_y + optimal_height, height)
+
+                        # Crop the patch
+                        cropped = crop_ndarray(
+                            image, x1=start_x, y1=start_y, x2=end_x, y2=end_y, channel_first=True
+                        )
+                        frames.append(cropped)
+
+                # Add global view (resized to max_size x max_size)
+                global_view = resize_ndarray(
+                    image,
+                    size=(max_width, max_height),  # PIL expects (width, height)
+                    resample=self.resample,
+                    channel_first=True,
+                )
+                frames.append(global_view)
+            else:
+                # Image is small enough, no splitting needed
+                frames.append(image)
+
+            # Append (not extend) to preserve per-image grouping
+            result.append(frames)
+
+        return result
+
+
+class SquareResize(Transform):
+    """
+    Resize images to square dimensions (max_size x max_size).
+    Works on numpy arrays in (C, H, W) format.
+    """
+
+    def __init__(
+        self,
+        size: int,
+        resample: Image.Resampling = Image.Resampling.LANCZOS,
+    ):
+        self.size = size
+        self.resample = resample
+
+    def __call__(self, images: list[NumpyArray]) -> list[list[NumpyArray]]:  # type: ignore[override]
+        return [
+            [
+                resize_ndarray(
+                    image, size=(self.size, self.size), resample=self.resample, channel_first=True
+                )
+            ]
+            for image in images
+        ]
+
+
+class Compose:
+    def __init__(self, transforms: list[Transform]):
+        self.transforms = transforms
+
+    def __call__(
+        self, images: list[Image.Image] | list[NumpyArray]
+    ) -> list[NumpyArray] | list[Image.Image]:
+        for transform in self.transforms:
+            images = transform(images)
+        return images
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "Compose":
+        """Creates processor from a config dict.
+        Args:
+            config (dict[str, Any]): Configuration dictionary.
+
+            Valid keys:
+                - do_resize
+                - resize_mode
+                - size
+                - fill_color
+                - do_center_crop
+                - crop_size
+                - do_rescale
+                - rescale_factor
+                - do_normalize
+                - image_mean
+                - mean
+                - image_std
+                - std
+                - resample
+                - interpolation
+            Valid size keys (nested):
+                - {"height", "width"}
+                - {"shortest_edge"}
+                - {"longest_edge"}
+
+        Returns:
+            Compose: Image processor.
+        """
+        transforms: list[Transform] = []
+        cls._get_convert_to_rgb(transforms, config)
+        cls._get_resize(transforms, config)
+        cls._get_pad2square(transforms, config)
+        cls._get_center_crop(transforms, config)
+        cls._get_pil2ndarray(transforms, config)
+        cls._get_image_splitting(transforms, config)
+        cls._get_rescale(transforms, config)
+        cls._get_normalize(transforms, config)
+        return cls(transforms=transforms)
+
+    @staticmethod
+    def _get_convert_to_rgb(transforms: list[Transform], config: dict[str, Any]) -> None:
+        transforms.append(ConvertToRGB())
+
+    @classmethod
+    def _get_resize(cls, transforms: list[Transform], config: dict[str, Any]) -> None:
+        mode = config.get("image_processor_type", "CLIPImageProcessor")
+        if mode in ("CLIPImageProcessor", "SiglipImageProcessor"):
+            if config.get("do_resize", False):
+                size = config["size"]
+                if "shortest_edge" in size:
+                    size = size["shortest_edge"]
+                elif "height" in size and "width" in size:
+                    size = (size["height"], size["width"])
+                else:
+                    raise ValueError(
+                        "Size must contain either 'shortest_edge' or 'height' and 'width'."
+                    )
+                transforms.append(
+                    Resize(
+                        size=size,
+                        resample=config.get("resample", Image.Resampling.BICUBIC),
+                    )
+                )
+        elif mode == "ConvNextFeatureExtractor":
+            if "size" in config and "shortest_edge" not in config["size"]:
+                raise ValueError(
+                    f"Size dictionary must contain 'shortest_edge' key. Got {config['size'].keys()}"
+                )
+            shortest_edge = config["size"]["shortest_edge"]
+            crop_pct = config.get("crop_pct", 0.875)
+            if shortest_edge < 384:
+                # maintain same ratio, resizing shortest edge to shortest_edge/crop_pct
+                resize_shortest_edge = int(shortest_edge / crop_pct)
+                transforms.append(
+                    Resize(
+                        size=resize_shortest_edge,
+                        resample=config.get("resample", Image.Resampling.BICUBIC),
+                    )
+                )
+                transforms.append(CenterCrop(size=(shortest_edge, shortest_edge)))
+            else:
+                transforms.append(
+                    Resize(
+                        size=(shortest_edge, shortest_edge),
+                        resample=config.get("resample", Image.Resampling.BICUBIC),
+                    )
+                )
+        elif mode == "JinaCLIPImageProcessor":
+            interpolation = config.get("interpolation")
+            if isinstance(interpolation, str):
+                resample = cls._interpolation_resolver(interpolation)
+            else:
+                resample = interpolation or Image.Resampling.BICUBIC
+
+            if "size" in config:
+                resize_mode = config.get("resize_mode", "shortest")
+                if resize_mode == "shortest":
+                    transforms.append(
+                        Resize(
+                            size=config["size"],
+                            resample=resample,
+                        )
+                    )
+        elif mode == "Idefics3ImageProcessor":
+            if config.get("do_resize", False):
+                size = config.get("size", {})
+                if "longest_edge" not in size:
+                    raise ValueError(
+                        "Size dictionary must contain 'longest_edge' key for Idefics3ImageProcessor"
+                    )
+
+                # Handle resample parameter - can be int enum or PIL.Image.Resampling
+                resample = config.get("resample", Image.Resampling.LANCZOS)
+                if isinstance(resample, int):
+                    resample = Image.Resampling(resample)
+
+                transforms.append(
+                    ResizeLongestEdge(
+                        size=size["longest_edge"],
+                        resample=resample,
+                    )
+                )
+        else:
+            raise ValueError(f"Preprocessor {mode} is not supported")
+
+    @staticmethod
+    def _get_center_crop(transforms: list[Transform], config: dict[str, Any]) -> None:
+        mode = config.get("image_processor_type", "CLIPImageProcessor")
+        if mode in ("CLIPImageProcessor", "SiglipImageProcessor"):
+            if config.get("do_center_crop", False):
+                crop_size_raw = config["crop_size"]
+                crop_size: tuple[int, int]
+                if isinstance(crop_size_raw, int):
+                    crop_size = (crop_size_raw, crop_size_raw)
+                elif isinstance(crop_size_raw, dict):
+                    crop_size = (crop_size_raw["height"], crop_size_raw["width"])
+                else:
+                    raise ValueError(f"Invalid crop size: {crop_size_raw}")
+                transforms.append(CenterCrop(size=crop_size))
+        elif mode == "ConvNextFeatureExtractor":
+            pass
+        elif mode == "JinaCLIPImageProcessor":
+            pass
+        elif mode == "Idefics3ImageProcessor":
+            pass
+        else:
+            raise ValueError(f"Preprocessor {mode} is not supported")
+
+    @staticmethod
+    def _get_pil2ndarray(transforms: list[Transform], config: dict[str, Any]) -> None:
+        transforms.append(PILtoNDarray())
+
+    @classmethod
+    def _get_image_splitting(cls, transforms: list[Transform], config: dict[str, Any]) -> None:
+        """
+        Add image splitting transforms for Idefics3.
+        Handles conditional logic: splitting vs square resize.
+        Must be called AFTER PILtoNDarray.
+        """
+        mode = config.get("image_processor_type", "CLIPImageProcessor")
+
+        if mode == "Idefics3ImageProcessor":
+            do_splitting = config.get("do_image_splitting", False)
+            max_size = config.get("max_image_size", {}).get("longest_edge", 512)
+            resample = config.get("resample", Image.Resampling.LANCZOS)
+            if isinstance(resample, int):
+                resample = Image.Resampling(resample)
+
+            if do_splitting:
+                transforms.append(ResizeForVisionEncoder(max_size, resample))
+                transforms.append(ImageSplitter(max_size, resample))
+            else:
+                transforms.append(SquareResize(max_size, resample))
+
+    @staticmethod
+    def _get_rescale(transforms: list[Transform], config: dict[str, Any]) -> None:
+        if config.get("do_rescale", True):
+            rescale_factor = config.get("rescale_factor", 1 / 255)
+            transforms.append(Rescale(scale=rescale_factor))
+
+    @staticmethod
+    def _get_normalize(transforms: list[Transform], config: dict[str, Any]) -> None:
+        if config.get("do_normalize", False):
+            transforms.append(Normalize(mean=config["image_mean"], std=config["image_std"]))
+        elif "mean" in config and "std" in config:
+            transforms.append(Normalize(mean=config["mean"], std=config["std"]))
+
+    @staticmethod
+    def _get_pad2square(transforms: list[Transform], config: dict[str, Any]) -> None:
+        mode = config.get("image_processor_type", "CLIPImageProcessor")
+        if mode == "CLIPImageProcessor":
+            pass
+        elif mode == "ConvNextFeatureExtractor":
+            pass
+        elif mode == "JinaCLIPImageProcessor":
+            transforms.append(
+                PadtoSquare(
+                    size=config["size"],
+                    fill_color=config.get("fill_color", 0),
+                )
+            )
+
+    @staticmethod
+    def _interpolation_resolver(resample: str | None = None) -> Image.Resampling:
+        interpolation_map = {
+            "nearest": Image.Resampling.NEAREST,
+            "lanczos": Image.Resampling.LANCZOS,
+            "bilinear": Image.Resampling.BILINEAR,
+            "bicubic": Image.Resampling.BICUBIC,
+            "box": Image.Resampling.BOX,
+            "hamming": Image.Resampling.HAMMING,
+        }
+
+        if resample and (method := interpolation_map.get(resample.lower())):
+            return method
+
+        raise ValueError(f"Unknown interpolation method: {resample}")