fastembed_bio-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. fastembed/__init__.py +24 -0
  2. fastembed/bio/__init__.py +3 -0
  3. fastembed/bio/protein_embedding.py +456 -0
  4. fastembed/common/__init__.py +3 -0
  5. fastembed/common/model_description.py +52 -0
  6. fastembed/common/model_management.py +471 -0
  7. fastembed/common/onnx_model.py +188 -0
  8. fastembed/common/preprocessor_utils.py +84 -0
  9. fastembed/common/types.py +27 -0
  10. fastembed/common/utils.py +69 -0
  11. fastembed/embedding.py +24 -0
  12. fastembed/image/__init__.py +3 -0
  13. fastembed/image/image_embedding.py +135 -0
  14. fastembed/image/image_embedding_base.py +55 -0
  15. fastembed/image/onnx_embedding.py +217 -0
  16. fastembed/image/onnx_image_model.py +156 -0
  17. fastembed/image/transform/functional.py +221 -0
  18. fastembed/image/transform/operators.py +499 -0
  19. fastembed/late_interaction/__init__.py +5 -0
  20. fastembed/late_interaction/colbert.py +301 -0
  21. fastembed/late_interaction/jina_colbert.py +58 -0
  22. fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
  23. fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
  24. fastembed/late_interaction/token_embeddings.py +83 -0
  25. fastembed/late_interaction_multimodal/__init__.py +5 -0
  26. fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
  27. fastembed/late_interaction_multimodal/colpali.py +327 -0
  28. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
  29. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
  30. fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
  31. fastembed/parallel_processor.py +253 -0
  32. fastembed/postprocess/__init__.py +3 -0
  33. fastembed/postprocess/muvera.py +362 -0
  34. fastembed/py.typed +1 -0
  35. fastembed/rerank/cross_encoder/__init__.py +3 -0
  36. fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
  37. fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
  38. fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
  39. fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
  40. fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
  41. fastembed/sparse/__init__.py +4 -0
  42. fastembed/sparse/bm25.py +359 -0
  43. fastembed/sparse/bm42.py +369 -0
  44. fastembed/sparse/minicoil.py +372 -0
  45. fastembed/sparse/sparse_embedding_base.py +90 -0
  46. fastembed/sparse/sparse_text_embedding.py +143 -0
  47. fastembed/sparse/splade_pp.py +196 -0
  48. fastembed/sparse/utils/minicoil_encoder.py +146 -0
  49. fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
  50. fastembed/sparse/utils/tokenizer.py +120 -0
  51. fastembed/sparse/utils/vocab_resolver.py +202 -0
  52. fastembed/text/__init__.py +3 -0
  53. fastembed/text/clip_embedding.py +56 -0
  54. fastembed/text/custom_text_embedding.py +97 -0
  55. fastembed/text/multitask_embedding.py +109 -0
  56. fastembed/text/onnx_embedding.py +353 -0
  57. fastembed/text/onnx_text_model.py +180 -0
  58. fastembed/text/pooled_embedding.py +136 -0
  59. fastembed/text/pooled_normalized_embedding.py +164 -0
  60. fastembed/text/text_embedding.py +228 -0
  61. fastembed/text/text_embedding_base.py +75 -0
  62. fastembed_bio-0.1.0.dist-info/METADATA +339 -0
  63. fastembed_bio-0.1.0.dist-info/RECORD +66 -0
  64. fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
  65. fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
  66. fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
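
A file listing like the one above can be reproduced locally from the wheel itself with the Python standard library (a sketch; it assumes the wheel has been downloaded to the working directory under its canonical file name):

    import zipfile

    # Wheels are zip archives; list the archive members.
    with zipfile.ZipFile("fastembed_bio-0.1.0-py3-none-any.whl") as whl:
        for name in whl.namelist():
            print(name)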
fastembed/image/transform/operators.py
@@ -0,0 +1,499 @@
+ from typing import Any
+ import math
+
+ from PIL import Image
+
+ from fastembed.common.types import NumpyArray
+ from fastembed.image.transform.functional import (
+     center_crop,
+     convert_to_rgb,
+     crop_ndarray,
+     normalize,
+     pil2ndarray,
+     rescale,
+     resize,
+     resize_longest_edge,
+     resize_ndarray,
+     pad2square,
+ )
+
+
+ class Transform:
+     def __call__(self, images: list[Any]) -> list[Image.Image] | list[NumpyArray]:
+         raise NotImplementedError("Subclasses must implement this method")
+
+
+ class ConvertToRGB(Transform):
+     def __call__(self, images: list[Image.Image]) -> list[Image.Image]:
+         return [convert_to_rgb(image=image) for image in images]
+
+
+ class CenterCrop(Transform):
+     def __init__(self, size: tuple[int, int]):
+         self.size = size
+
+     def __call__(self, images: list[Image.Image]) -> list[NumpyArray]:
+         return [center_crop(image=image, size=self.size) for image in images]
+
+
+ class Normalize(Transform):
+     def __init__(self, mean: float | list[float], std: float | list[float]):
+         self.mean = mean
+         self.std = std
+
+     def __call__(  # type: ignore[override]
+         self, images: list[NumpyArray] | list[list[NumpyArray]]
+     ) -> list[NumpyArray] | list[list[NumpyArray]]:
+         if images and isinstance(images[0], list):
+             # Nested structure from ImageSplitter
+             return [
+                 [normalize(image, mean=self.mean, std=self.std) for image in img_patches]  # type: ignore[arg-type]
+                 for img_patches in images
+             ]
+         else:
+             # Flat structure (backward compatibility)
+             return [normalize(image, mean=self.mean, std=self.std) for image in images]  # type: ignore[arg-type]
+
+
+ class Resize(Transform):
+     def __init__(
+         self,
+         size: int | tuple[int, int],
+         resample: Image.Resampling = Image.Resampling.BICUBIC,
+     ):
+         self.size = size
+         self.resample = resample
+
+     def __call__(self, images: list[Image.Image]) -> list[Image.Image]:
+         return [resize(image, size=self.size, resample=self.resample) for image in images]
+
+
+ class Rescale(Transform):
+     def __init__(self, scale: float = 1 / 255):
+         self.scale = scale
+
+     def __call__(  # type: ignore[override]
+         self, images: list[NumpyArray] | list[list[NumpyArray]]
+     ) -> list[NumpyArray] | list[list[NumpyArray]]:
+         if images and isinstance(images[0], list):
+             # Nested structure from ImageSplitter
+             return [
+                 [rescale(image, scale=self.scale) for image in img_patches]  # type: ignore[arg-type]
+                 for img_patches in images
+             ]
+         else:
+             # Flat structure (backward compatibility)
+             return [rescale(image, scale=self.scale) for image in images]  # type: ignore[arg-type]
+
+
+ class PILtoNDarray(Transform):
+     def __call__(self, images: list[Image.Image | NumpyArray]) -> list[NumpyArray]:
+         return [pil2ndarray(image) for image in images]
+
+
+ class PadtoSquare(Transform):
+     def __init__(
+         self,
+         size: int,
+         fill_color: str | int | tuple[int, ...],
+     ):
+         self.size = size
+         self.fill_color = fill_color
+
+     def __call__(self, images: list[Image.Image]) -> list[Image.Image]:
+         return [
+             pad2square(image=image, size=self.size, fill_color=self.fill_color) for image in images
+         ]
+
+
+ class ResizeLongestEdge(Transform):
+     """Resize images so the longest edge equals target size, preserving aspect ratio."""
+
+     def __init__(
+         self,
+         size: int,
+         resample: Image.Resampling = Image.Resampling.LANCZOS,
+     ):
+         self.size = size
+         self.resample = resample
+
+     def __call__(self, images: list[Image.Image]) -> list[Image.Image]:
+         return [resize_longest_edge(image, self.size, self.resample) for image in images]
+
+
+ class ResizeForVisionEncoder(Transform):
+     """
+     Resize both dimensions to be multiples of vision_encoder_max_size.
+     Preserves aspect ratio approximately.
+     Works on numpy arrays in (C, H, W) format.
+     """
+
+     def __init__(
+         self,
+         max_size: int,
+         resample: Image.Resampling = Image.Resampling.LANCZOS,
+     ):
+         self.max_size = max_size
+         self.resample = resample
+
+     def __call__(self, images: list[NumpyArray]) -> list[NumpyArray]:
+         result = []
+         for image in images:
+             # Assume (C, H, W) format
+             _, height, width = image.shape
+
+             aspect_ratio = width / height
+
+             if width >= height:
+                 # Calculate new width as multiple of max_size
+                 new_width = math.ceil(width / self.max_size) * self.max_size
+                 new_height = int(new_width / aspect_ratio)
+                 new_height = math.ceil(new_height / self.max_size) * self.max_size
+             else:
+                 # Calculate new height as multiple of max_size
+                 new_height = math.ceil(height / self.max_size) * self.max_size
+                 new_width = int(new_height * aspect_ratio)
+                 new_width = math.ceil(new_width / self.max_size) * self.max_size
+
+             # Resize using the ndarray resize function
+             resized = resize_ndarray(
+                 image,
+                 size=(new_width, new_height),  # PIL expects (width, height)
+                 resample=self.resample,
+                 channel_first=True,
+             )
+             result.append(resized)
+
+         return result
+
+
+ class ImageSplitter(Transform):
+     """
+     Split images into grid of patches plus a global view.
+
+     If image dimensions exceed max_size:
+     - Divide into ceil(H/max_size) x ceil(W/max_size) patches
+     - Each patch is cropped from the image
+     - Add a global view (original resized to max_size x max_size)
+
+     If image is smaller than max_size:
+     - Return single image unchanged
+
+     Works on numpy arrays in (C, H, W) format.
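+
+     Example: a 1030x2060 (H, W) image with max_size=512 splits into a
+     3 x 5 grid of 15 patches plus one 512x512 global view (16 frames).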
+     """
+
+     def __init__(
+         self,
+         max_size: int,
+         resample: Image.Resampling = Image.Resampling.LANCZOS,
+     ):
+         self.max_size = max_size
+         self.resample = resample
+
+     def __call__(self, images: list[NumpyArray]) -> list[list[NumpyArray]]:  # type: ignore[override]
+         result = []
+
+         for image in images:
+             # Assume (C, H, W) format
+             _, height, width = image.shape
+             max_height = max_width = self.max_size
+
+             frames = []
+
+             if height > max_height or width > max_width:
+                 # Calculate the number of splits needed
+                 num_splits_h = math.ceil(height / max_height)
+                 num_splits_w = math.ceil(width / max_width)
+
+                 # Calculate optimal patch dimensions
+                 optimal_height = math.ceil(height / num_splits_h)
+                 optimal_width = math.ceil(width / num_splits_w)
+
+                 # Generate patches in grid order (row by row)
+                 for r in range(num_splits_h):
+                     for c in range(num_splits_w):
+                         # Calculate crop coordinates
+                         start_x = c * optimal_width
+                         start_y = r * optimal_height
+                         end_x = min(start_x + optimal_width, width)
+                         end_y = min(start_y + optimal_height, height)
+
+                         # Crop the patch
+                         cropped = crop_ndarray(
+                             image, x1=start_x, y1=start_y, x2=end_x, y2=end_y, channel_first=True
+                         )
+                         frames.append(cropped)
+
+                 # Add global view (resized to max_size x max_size)
+                 global_view = resize_ndarray(
+                     image,
+                     size=(max_width, max_height),  # PIL expects (width, height)
+                     resample=self.resample,
+                     channel_first=True,
+                 )
+                 frames.append(global_view)
+             else:
+                 # Image is small enough, no splitting needed
+                 frames.append(image)
+
+             # Append (not extend) to preserve per-image grouping
+             result.append(frames)
+
+         return result
+
+
+ class SquareResize(Transform):
+     """
+     Resize images to square dimensions (max_size x max_size).
+     Works on numpy arrays in (C, H, W) format.
+     """
+
+     def __init__(
+         self,
+         size: int,
+         resample: Image.Resampling = Image.Resampling.LANCZOS,
+     ):
+         self.size = size
+         self.resample = resample
+
+     def __call__(self, images: list[NumpyArray]) -> list[list[NumpyArray]]:  # type: ignore[override]
+         return [
+             [
+                 resize_ndarray(
+                     image, size=(self.size, self.size), resample=self.resample, channel_first=True
+                 )
+             ]
+             for image in images
+         ]
+
+
+ class Compose:
+     def __init__(self, transforms: list[Transform]):
+         self.transforms = transforms
+
+     def __call__(
+         self, images: list[Image.Image] | list[NumpyArray]
+     ) -> list[NumpyArray] | list[Image.Image]:
+         for transform in self.transforms:
+             images = transform(images)
+         return images
+
+     @classmethod
+     def from_config(cls, config: dict[str, Any]) -> "Compose":
+         """Creates processor from a config dict.
+
+         Args:
+             config (dict[str, Any]): Configuration dictionary.
+
+         Valid keys:
+             - do_resize
+             - resize_mode
+             - size
+             - fill_color
+             - do_center_crop
+             - crop_size
+             - do_rescale
+             - rescale_factor
+             - do_normalize
+             - image_mean
+             - mean
+             - image_std
+             - std
+             - resample
+             - interpolation
+
+         Valid size keys (nested):
+             - {"height", "width"}
+             - {"shortest_edge"}
+             - {"longest_edge"}
+
+         Returns:
+             Compose: Image processor.
+         """
+         transforms: list[Transform] = []
+         cls._get_convert_to_rgb(transforms, config)
+         cls._get_resize(transforms, config)
+         cls._get_pad2square(transforms, config)
+         cls._get_center_crop(transforms, config)
+         cls._get_pil2ndarray(transforms, config)
+         cls._get_image_splitting(transforms, config)
+         cls._get_rescale(transforms, config)
+         cls._get_normalize(transforms, config)
+         return cls(transforms=transforms)
+
+     @staticmethod
+     def _get_convert_to_rgb(transforms: list[Transform], config: dict[str, Any]) -> None:
+         transforms.append(ConvertToRGB())
+
+     @classmethod
+     def _get_resize(cls, transforms: list[Transform], config: dict[str, Any]) -> None:
+         mode = config.get("image_processor_type", "CLIPImageProcessor")
+         if mode in ("CLIPImageProcessor", "SiglipImageProcessor"):
+             if config.get("do_resize", False):
+                 size = config["size"]
+                 if "shortest_edge" in size:
+                     size = size["shortest_edge"]
+                 elif "height" in size and "width" in size:
+                     size = (size["height"], size["width"])
+                 else:
+                     raise ValueError(
+                         "Size must contain either 'shortest_edge' or 'height' and 'width'."
+                     )
+                 transforms.append(
+                     Resize(
+                         size=size,
+                         resample=config.get("resample", Image.Resampling.BICUBIC),
+                     )
+                 )
+         elif mode == "ConvNextFeatureExtractor":
+             if "size" in config and "shortest_edge" not in config["size"]:
+                 raise ValueError(
+                     f"Size dictionary must contain 'shortest_edge' key. Got {config['size'].keys()}"
+                 )
+             shortest_edge = config["size"]["shortest_edge"]
+             crop_pct = config.get("crop_pct", 0.875)
+             if shortest_edge < 384:
+                 # maintain same ratio, resizing shortest edge to shortest_edge/crop_pct
+                 resize_shortest_edge = int(shortest_edge / crop_pct)
+                 transforms.append(
+                     Resize(
+                         size=resize_shortest_edge,
+                         resample=config.get("resample", Image.Resampling.BICUBIC),
+                     )
+                 )
+                 transforms.append(CenterCrop(size=(shortest_edge, shortest_edge)))
+             else:
+                 transforms.append(
+                     Resize(
+                         size=(shortest_edge, shortest_edge),
+                         resample=config.get("resample", Image.Resampling.BICUBIC),
+                     )
+                 )
+         elif mode == "JinaCLIPImageProcessor":
+             interpolation = config.get("interpolation")
+             if isinstance(interpolation, str):
+                 resample = cls._interpolation_resolver(interpolation)
+             else:
+                 resample = interpolation or Image.Resampling.BICUBIC
+
+             if "size" in config:
+                 resize_mode = config.get("resize_mode", "shortest")
+                 if resize_mode == "shortest":
+                     transforms.append(
+                         Resize(
+                             size=config["size"],
+                             resample=resample,
+                         )
+                     )
+         elif mode == "Idefics3ImageProcessor":
+             if config.get("do_resize", False):
+                 size = config.get("size", {})
+                 if "longest_edge" not in size:
+                     raise ValueError(
+                         "Size dictionary must contain 'longest_edge' key for Idefics3ImageProcessor"
+                     )
+
+                 # Handle resample parameter - can be int enum or PIL.Image.Resampling
+                 resample = config.get("resample", Image.Resampling.LANCZOS)
+                 if isinstance(resample, int):
+                     resample = Image.Resampling(resample)
+
+                 transforms.append(
+                     ResizeLongestEdge(
+                         size=size["longest_edge"],
+                         resample=resample,
+                     )
+                 )
+         else:
+             raise ValueError(f"Preprocessor {mode} is not supported")
+
+     @staticmethod
+     def _get_center_crop(transforms: list[Transform], config: dict[str, Any]) -> None:
+         mode = config.get("image_processor_type", "CLIPImageProcessor")
+         if mode in ("CLIPImageProcessor", "SiglipImageProcessor"):
+             if config.get("do_center_crop", False):
+                 crop_size_raw = config["crop_size"]
+                 crop_size: tuple[int, int]
+                 if isinstance(crop_size_raw, int):
+                     crop_size = (crop_size_raw, crop_size_raw)
+                 elif isinstance(crop_size_raw, dict):
+                     crop_size = (crop_size_raw["height"], crop_size_raw["width"])
+                 else:
+                     raise ValueError(f"Invalid crop size: {crop_size_raw}")
+                 transforms.append(CenterCrop(size=crop_size))
+         elif mode == "ConvNextFeatureExtractor":
+             pass
+         elif mode == "JinaCLIPImageProcessor":
+             pass
+         elif mode == "Idefics3ImageProcessor":
+             pass
+         else:
+             raise ValueError(f"Preprocessor {mode} is not supported")
+
+     @staticmethod
+     def _get_pil2ndarray(transforms: list[Transform], config: dict[str, Any]) -> None:
+         transforms.append(PILtoNDarray())
+
+     @classmethod
+     def _get_image_splitting(cls, transforms: list[Transform], config: dict[str, Any]) -> None:
+         """
+         Add image splitting transforms for Idefics3.
+         Handles conditional logic: splitting vs square resize.
+         Must be called AFTER PILtoNDarray.
+         """
+         mode = config.get("image_processor_type", "CLIPImageProcessor")
+
+         if mode == "Idefics3ImageProcessor":
+             do_splitting = config.get("do_image_splitting", False)
+             max_size = config.get("max_image_size", {}).get("longest_edge", 512)
+             resample = config.get("resample", Image.Resampling.LANCZOS)
+             if isinstance(resample, int):
+                 resample = Image.Resampling(resample)
+
+             if do_splitting:
+                 transforms.append(ResizeForVisionEncoder(max_size, resample))
+                 transforms.append(ImageSplitter(max_size, resample))
+             else:
+                 transforms.append(SquareResize(max_size, resample))
+
+     @staticmethod
+     def _get_rescale(transforms: list[Transform], config: dict[str, Any]) -> None:
+         if config.get("do_rescale", True):
+             rescale_factor = config.get("rescale_factor", 1 / 255)
+             transforms.append(Rescale(scale=rescale_factor))
+
+     @staticmethod
+     def _get_normalize(transforms: list[Transform], config: dict[str, Any]) -> None:
+         if config.get("do_normalize", False):
+             transforms.append(Normalize(mean=config["image_mean"], std=config["image_std"]))
+         elif "mean" in config and "std" in config:
+             transforms.append(Normalize(mean=config["mean"], std=config["std"]))
+
+     @staticmethod
+     def _get_pad2square(transforms: list[Transform], config: dict[str, Any]) -> None:
+         mode = config.get("image_processor_type", "CLIPImageProcessor")
+         if mode == "CLIPImageProcessor":
+             pass
+         elif mode == "ConvNextFeatureExtractor":
+             pass
+         elif mode == "JinaCLIPImageProcessor":
+             transforms.append(
+                 PadtoSquare(
+                     size=config["size"],
+                     fill_color=config.get("fill_color", 0),
+                 )
+             )
+
+     @staticmethod
+     def _interpolation_resolver(resample: str | None = None) -> Image.Resampling:
+         interpolation_map = {
+             "nearest": Image.Resampling.NEAREST,
+             "lanczos": Image.Resampling.LANCZOS,
+             "bilinear": Image.Resampling.BILINEAR,
+             "bicubic": Image.Resampling.BICUBIC,
+             "box": Image.Resampling.BOX,
+             "hamming": Image.Resampling.HAMMING,
+         }
+
+         if resample and (method := interpolation_map.get(resample.lower())):
+             return method
+
+         raise ValueError(f"Unknown interpolation method: {resample}")
fastembed/late_interaction/__init__.py
@@ -0,0 +1,5 @@
+ from fastembed.late_interaction.late_interaction_text_embedding import (
+     LateInteractionTextEmbedding,
+ )
+
+ __all__ = ["LateInteractionTextEmbedding"]
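
A minimal late-interaction sketch against this export (assuming the fork keeps upstream fastembed's constructor and method names; the model name is upstream's ColBERT example and may not be what this package ships):

    from fastembed.late_interaction import LateInteractionTextEmbedding

    model = LateInteractionTextEmbedding("colbert-ir/colbertv2.0")  # illustrative model name
    docs = ["ColBERT keeps one vector per token instead of pooling."]
    doc_embeddings = list(model.embed(docs))  # each item: (num_tokens, dim) array
    query_embedding = next(model.query_embed("what does ColBERT keep per token?"))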