fastembed-bio 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. fastembed/__init__.py +24 -0
  2. fastembed/bio/__init__.py +3 -0
  3. fastembed/bio/protein_embedding.py +456 -0
  4. fastembed/common/__init__.py +3 -0
  5. fastembed/common/model_description.py +52 -0
  6. fastembed/common/model_management.py +471 -0
  7. fastembed/common/onnx_model.py +188 -0
  8. fastembed/common/preprocessor_utils.py +84 -0
  9. fastembed/common/types.py +27 -0
  10. fastembed/common/utils.py +69 -0
  11. fastembed/embedding.py +24 -0
  12. fastembed/image/__init__.py +3 -0
  13. fastembed/image/image_embedding.py +135 -0
  14. fastembed/image/image_embedding_base.py +55 -0
  15. fastembed/image/onnx_embedding.py +217 -0
  16. fastembed/image/onnx_image_model.py +156 -0
  17. fastembed/image/transform/functional.py +221 -0
  18. fastembed/image/transform/operators.py +499 -0
  19. fastembed/late_interaction/__init__.py +5 -0
  20. fastembed/late_interaction/colbert.py +301 -0
  21. fastembed/late_interaction/jina_colbert.py +58 -0
  22. fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
  23. fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
  24. fastembed/late_interaction/token_embeddings.py +83 -0
  25. fastembed/late_interaction_multimodal/__init__.py +5 -0
  26. fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
  27. fastembed/late_interaction_multimodal/colpali.py +327 -0
  28. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
  29. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
  30. fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
  31. fastembed/parallel_processor.py +253 -0
  32. fastembed/postprocess/__init__.py +3 -0
  33. fastembed/postprocess/muvera.py +362 -0
  34. fastembed/py.typed +1 -0
  35. fastembed/rerank/cross_encoder/__init__.py +3 -0
  36. fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
  37. fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
  38. fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
  39. fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
  40. fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
  41. fastembed/sparse/__init__.py +4 -0
  42. fastembed/sparse/bm25.py +359 -0
  43. fastembed/sparse/bm42.py +369 -0
  44. fastembed/sparse/minicoil.py +372 -0
  45. fastembed/sparse/sparse_embedding_base.py +90 -0
  46. fastembed/sparse/sparse_text_embedding.py +143 -0
  47. fastembed/sparse/splade_pp.py +196 -0
  48. fastembed/sparse/utils/minicoil_encoder.py +146 -0
  49. fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
  50. fastembed/sparse/utils/tokenizer.py +120 -0
  51. fastembed/sparse/utils/vocab_resolver.py +202 -0
  52. fastembed/text/__init__.py +3 -0
  53. fastembed/text/clip_embedding.py +56 -0
  54. fastembed/text/custom_text_embedding.py +97 -0
  55. fastembed/text/multitask_embedding.py +109 -0
  56. fastembed/text/onnx_embedding.py +353 -0
  57. fastembed/text/onnx_text_model.py +180 -0
  58. fastembed/text/pooled_embedding.py +136 -0
  59. fastembed/text/pooled_normalized_embedding.py +164 -0
  60. fastembed/text/text_embedding.py +228 -0
  61. fastembed/text/text_embedding_base.py +75 -0
  62. fastembed_bio-0.1.0.dist-info/METADATA +339 -0
  63. fastembed_bio-0.1.0.dist-info/RECORD +66 -0
  64. fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
  65. fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
  66. fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
@@ -0,0 +1,156 @@
1
+ import contextlib
2
+ import os
3
+ from multiprocessing import get_all_start_methods
4
+ from pathlib import Path
5
+ from typing import Any, Iterable, Sequence, Type
6
+
7
+ import numpy as np
8
+ from PIL import Image
9
+
10
+ from fastembed.image.transform.operators import Compose
11
+ from fastembed.common.types import NumpyArray, Device
12
+ from fastembed.common import ImageInput, OnnxProvider
13
+ from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxOutputContext, T
14
+ from fastembed.common.preprocessor_utils import load_preprocessor
15
+ from fastembed.common.utils import iter_batch
16
+ from fastembed.parallel_processor import ParallelWorkerPool
17
+
18
+ # T, imported from fastembed.common.onnx_model above, holds the type of the embedding result
19
+
20
+
21
class OnnxImageModel(OnnxModel[T]):
    """Base class for ONNX-backed image embedding models.

    Subclasses provide ``_get_worker_class``, ``_post_process_onnx_output`` and
    ``load_onnx_model``. This class handles opening and preprocessing images,
    running the ONNX session batch by batch, and optionally distributing
    batches over a ``ParallelWorkerPool``.
    """

    @classmethod
    def _get_worker_class(cls) -> Type["ImageEmbeddingWorker[T]"]:
        """Return the worker class handed to the parallel worker pool."""
        raise NotImplementedError("Subclasses must implement this method")

    def _post_process_onnx_output(self, output: OnnxOutputContext, **kwargs: Any) -> Iterable[T]:
        """Post-process the ONNX model output to convert it into a usable format.

        Args:
            output (OnnxOutputContext): The raw output from the ONNX model.
            **kwargs: Additional keyword arguments that may be needed by specific implementations.

        Returns:
            Iterable[T]: Post-processed output as an iterable of type T.
        """
        raise NotImplementedError("Subclasses must implement this method")

    def __init__(self) -> None:
        super().__init__()
        # Image preprocessing pipeline; set by _load_onnx_model.
        self.processor: Compose | None = None

    def _preprocess_onnx_input(
        self, onnx_input: dict[str, NumpyArray], **kwargs: Any
    ) -> dict[str, NumpyArray]:
        """Hook for adjusting the ONNX input; this base version returns it unchanged."""
        return onnx_input

    def _load_onnx_model(
        self,
        model_dir: Path,
        model_file: str,
        threads: int | None,
        providers: Sequence[OnnxProvider] | None = None,
        cuda: bool | Device = Device.AUTO,
        device_id: int | None = None,
        extra_session_options: dict[str, Any] | None = None,
    ) -> None:
        """Load the ONNX model through the parent class, then the image preprocessor.

        All arguments are forwarded unchanged to the parent ``_load_onnx_model``;
        ``model_dir`` is additionally used to load the preprocessor.
        """
        super()._load_onnx_model(
            model_dir=model_dir,
            model_file=model_file,
            threads=threads,
            providers=providers,
            cuda=cuda,
            device_id=device_id,
            extra_session_options=extra_session_options,
        )
        self.processor = load_preprocessor(model_dir=model_dir)

    def load_onnx_model(self) -> None:
        """Load the model; must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement this method")

    def _build_onnx_input(self, encoded: NumpyArray) -> dict[str, NumpyArray]:
        """Map the encoded batch to the name of the model's first input."""
        input_name = self.model.get_inputs()[0].name  # type: ignore[union-attr]
        return {input_name: encoded}

    def onnx_embed(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext:
        """Run the model on one batch of images.

        Args:
            images: Batch items. Anything that is not already a
                ``PIL.Image.Image`` is opened with ``Image.open``.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            OnnxOutputContext whose ``model_output`` is the first model output
            reshaped to ``(len(images), -1)``.

        Raises:
            AssertionError: If the preprocessor has not been loaded.
        """
        # ExitStack closes every file opened here once preprocessing is done.
        with contextlib.ExitStack() as stack:
            image_files = [
                stack.enter_context(Image.open(image))
                if not isinstance(image, Image.Image)
                else image
                for image in images
            ]
            # Raised explicitly (same type and message as the former `assert`)
            # so the check is not stripped when Python runs with -O.
            if self.processor is None:
                raise AssertionError("Processor is not initialized")
            encoded = np.array(self.processor(image_files))
        onnx_input = self._build_onnx_input(encoded)
        onnx_input = self._preprocess_onnx_input(onnx_input)
        model_output = self.model.run(None, onnx_input)  # type: ignore[union-attr]
        embeddings = model_output[0].reshape(len(images), -1)
        return OnnxOutputContext(model_output=embeddings)

    def _embed_images(
        self,
        model_name: str,
        cache_dir: str,
        images: ImageInput | Iterable[ImageInput],
        batch_size: int = 256,
        parallel: int | None = None,
        providers: Sequence[OnnxProvider] | None = None,
        cuda: bool | Device = Device.AUTO,
        device_ids: list[int] | None = None,
        local_files_only: bool = False,
        specific_model_path: str | None = None,
        extra_session_options: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Iterable[T]:
        """Lazily embed one image or an iterable of images.

        This is a generator: nothing is loaded or computed until it is iterated.

        Args:
            model_name: Forwarded to the workers (parallel path only).
            cache_dir: Forwarded to the workers (parallel path only).
            images: A single path / PIL image, or an iterable of them.
            batch_size: Number of images per batch.
            parallel: ``None`` embeds in the current process; ``0`` uses
                ``os.cpu_count()`` workers; any other value is the worker count.
            providers: Forwarded to the workers (parallel path only).
            cuda: Passed to the worker pool.
            device_ids: Passed to the worker pool.
            local_files_only: Forwarded to the workers (parallel path only).
            specific_model_path: Forwarded to the workers (parallel path only).
            extra_session_options: If given, merged flat into the parameters
                forwarded to the workers.
            **kwargs: Passed to ``_post_process_onnx_output`` and, on the
                parallel path, also forwarded to the workers.

        Yields:
            Post-processed embeddings, as produced by ``_post_process_onnx_output``.
        """
        is_small = False

        if isinstance(images, (str, Path, Image.Image)):
            images = [images]
            is_small = True

        if isinstance(images, list) and len(images) < batch_size:
            is_small = True

        # A single image or a list shorter than one batch is always handled
        # in-process, even when `parallel` was requested.
        if parallel is None or is_small:
            if not hasattr(self, "model") or self.model is None:
                self.load_onnx_model()

            for batch in iter_batch(images, batch_size):
                yield from self._post_process_onnx_output(self.onnx_embed(batch), **kwargs)
        else:
            if parallel == 0:
                # May be None; the `or 1` below covers that case.
                parallel = os.cpu_count()

            start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
            params = {
                "model_name": model_name,
                "cache_dir": cache_dir,
                "providers": providers,
                "local_files_only": local_files_only,
                "specific_model_path": specific_model_path,
                **kwargs,
            }

            if extra_session_options is not None:
                params.update(extra_session_options)

            pool = ParallelWorkerPool(
                num_workers=parallel or 1,
                worker=self._get_worker_class(),
                cuda=cuda,
                device_ids=device_ids,
                start_method=start_method,
            )
            for batch in pool.ordered_map(iter_batch(images, batch_size), **params):
                yield from self._post_process_onnx_output(batch, **kwargs)  # type: ignore
150
+
151
+
152
class ImageEmbeddingWorker(EmbeddingWorker[T]):
    """Worker that runs ``self.model.onnx_embed`` on incoming batches."""

    def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
        """Yield ``(index, onnx_embed(batch))`` for every ``(index, batch)`` pair in ``items``."""
        for batch_id, image_batch in items:
            yield batch_id, self.model.onnx_embed(image_batch)
@@ -0,0 +1,221 @@
1
+ import numpy as np
2
+ from PIL import Image
3
+
4
+ from fastembed.common.types import NumpyArray
5
+
6
+
7
def convert_to_rgb(image: Image.Image) -> Image.Image:
    """Return ``image`` in RGB mode; an image that is already RGB is returned as-is."""
    return image if image.mode == "RGB" else image.convert("RGB")
13
+
14
+
15
def center_crop(
    image: Image.Image | NumpyArray,
    size: tuple[int, int],
) -> NumpyArray:
    """Crop the central ``size`` region of an image, zero-padding where it is too small.

    Args:
        image: A PIL image, or a 3-D array laid out as (C, H, W).
        size: Target ``(height, width)``.

    Returns:
        Array of shape (C, height, width). When no padding is needed the result
        is a slice of the input and keeps its dtype; when padding is needed a
        new float32 array is returned.
    """
    if isinstance(image, np.ndarray):
        _, orig_height, orig_width = image.shape
    else:
        orig_height, orig_width = image.height, image.width
        # (H, W, C) -> (C, H, W)
        image = np.array(image).transpose((2, 0, 1))

    crop_height, crop_width = size

    # left upper corner (0, 0)
    top = (orig_height - crop_height) // 2
    bottom = top + crop_height
    left = (orig_width - crop_width) // 2
    right = left + crop_width

    # Check if cropped area is within image boundaries
    if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width:
        image = image[..., top:bottom, left:right]
        return image

    # Padding with zeros
    new_height = max(crop_height, orig_height)
    new_width = max(crop_width, orig_width)
    new_shape = image.shape[:-2] + (new_height, new_width)
    new_image = np.zeros_like(image, shape=new_shape, dtype=np.float32)

    # Ceil division: `top`/`left` above were floored (and may be negative), so
    # the pad offset must round the other way. With floor here, an odd size
    # difference left the window one pixel short of the requested size.
    top_pad = (new_height - orig_height + 1) // 2
    bottom_pad = top_pad + orig_height
    left_pad = (new_width - orig_width + 1) // 2
    right_pad = left_pad + orig_width
    new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image

    # Shift the crop window into the coordinates of the padded canvas.
    top += top_pad
    bottom += top_pad
    left += left_pad
    right += left_pad

    new_image = new_image[
        ..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)
    ]

    return new_image
61
+
62
+
63
def normalize(
    image: NumpyArray,
    mean: float | list[float],
    std: float | list[float],
) -> NumpyArray:
    """Normalize an image per channel as ``(image - mean) / std``.

    Args:
        image: Array whose channel axis is axis 1 when it is 4-D and axis 0
            otherwise. Non-floating input is converted to float32 first.
        mean: One value for all channels, or a list with one value per channel.
        std: One value for all channels, or a list with one value per channel.

    Returns:
        The normalized array, with the same shape as ``image``.

    Raises:
        ValueError: If ``mean`` or ``std`` is a list whose length differs from
            the number of channels.
    """
    channel_axis = 1 if image.ndim == 4 else 0
    num_channels = image.shape[channel_axis]

    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float32)

    mean_list = mean if isinstance(mean, list) else [mean] * num_channels

    if len(mean_list) != num_channels:
        raise ValueError(
            f"mean must have the same number of channels as the image, image has {num_channels} channels, got "
            f"{len(mean_list)}"
        )

    mean_arr = np.array(mean_list, dtype=np.float32)

    std_list = std if isinstance(std, list) else [std] * num_channels
    if len(std_list) != num_channels:
        raise ValueError(
            f"std must have the same number of channels as the image, image has {num_channels} channels, got {len(std_list)}"
        )

    std_arr = np.array(std_list, dtype=np.float32)

    # Reshape the statistics to (C, 1, ..., 1) so they broadcast along the
    # channel axis. The previous transpose-based broadcast lined the statistics
    # up with axis 0, which is the batch axis for 4-D input.
    stat_shape = (-1,) + (1,) * (image.ndim - channel_axis - 1)
    return (image - mean_arr.reshape(stat_shape)) / std_arr.reshape(stat_shape)
93
+
94
+
95
def resize(
    image: Image.Image,
    size: int | tuple[int, int],
    resample: int | Image.Resampling = Image.Resampling.BILINEAR,
) -> Image.Image:
    """Resize a PIL image.

    Args:
        image: Image to resize.
        size: A tuple is passed to ``Image.resize`` unchanged. An int is the
            target length of the shorter edge; the longer edge is scaled by the
            same factor and truncated to an int.
        resample: Resampling filter passed to ``Image.resize``.

    Returns:
        The resized image.
    """
    if isinstance(size, tuple):
        return image.resize(size, resample)

    width, height = image.width, image.height
    if width <= height:
        # Width is the shorter edge.
        target = (size, int(size * height / width))
    else:
        # Height is the shorter edge.
        target = (int(size * width / height), size)
    return image.resize(target, resample)
112
+
113
+
114
def rescale(image: NumpyArray, scale: float, dtype: type = np.float32) -> NumpyArray:
    """Multiply every element of ``image`` by ``scale`` and cast the result to ``dtype``."""
    scaled = image * scale
    return scaled.astype(dtype)
116
+
117
+
118
def pil2ndarray(image: Image.Image | NumpyArray) -> NumpyArray:
    """Convert a PIL image to a channel-first array; any other input is returned unchanged."""
    if not isinstance(image, Image.Image):
        return image
    # (H, W, C) -> (C, H, W)
    hwc = np.asarray(image)
    return hwc.transpose((2, 0, 1))
122
+
123
+
124
def pad2square(
    image: Image.Image,
    size: int,
    fill_color: str | int | tuple[int, ...] = 0,
) -> Image.Image:
    """Fit an image onto a new ``size`` x ``size`` RGB canvas.

    Any edge longer than ``size`` is center-cropped to ``size`` first. The
    result is pasted with PIL's default paste position, and the rest of the
    canvas keeps ``fill_color``.

    Args:
        image: Source image.
        size: Edge length of the square output.
        fill_color: Colour used to create the canvas.

    Returns:
        A new RGB image of size ``(size, size)``.
    """
    width, height = image.width, image.height
    too_wide = width > size
    too_tall = height > size

    # Horizontal extent of the region to keep.
    left = (width - size) // 2 if too_wide else 0
    right = left + size if too_wide else width

    # Vertical extent of the region to keep.
    top = (height - size) // 2 if too_tall else 0
    bottom = top + size if too_tall else height

    canvas = Image.new(mode="RGB", size=(size, size), color=fill_color)
    if too_wide or too_tall:
        canvas.paste(image.crop((left, top, right, bottom)))
    else:
        canvas.paste(image)
    return canvas
148
+
149
+
150
def resize_longest_edge(
    image: Image.Image,
    max_size: int,
    resample: int | Image.Resampling = Image.Resampling.LANCZOS,
) -> Image.Image:
    """Resize an image so its longer edge becomes ``max_size``, keeping the aspect ratio.

    The shorter edge is scaled by the same factor and truncated to an int. Both
    edges are then rounded up to the next even number, so an edge can end up
    one pixel larger than ``max_size`` when ``max_size`` is odd.

    Args:
        image: Image to resize.
        max_size: Target length of the longer edge.
        resample: Resampling filter passed to ``Image.resize``.

    Returns:
        The resized image.
    """
    height, width = image.height, image.width
    aspect_ratio = width / height

    if width >= height:
        # Width is longer
        new_width = max_size
        new_height = int(new_width / aspect_ratio)
    else:
        # Height is longer
        new_height = max_size
        new_width = int(new_height * aspect_ratio)

    # For extreme aspect ratios the shorter edge truncates to 0, which stays 0
    # after the even-rounding below and is not a valid resize target.
    new_height = max(new_height, 1)
    new_width = max(new_width, 1)

    # Ensure even dimensions
    if new_height % 2 != 0:
        new_height += 1
    if new_width % 2 != 0:
        new_width += 1

    return image.resize((new_width, new_height), resample)
174
+
175
+
176
def crop_ndarray(
    image: NumpyArray,
    x1: int,
    y1: int,
    x2: int,
    y2: int,
    channel_first: bool = True,
) -> NumpyArray:
    """Slice the rectangle ``[y1:y2, x1:x2]`` out of a 3-D image array.

    Args:
        image: Array in (C, H, W) layout when ``channel_first`` is true,
            otherwise in (H, W, C) layout.
        x1: Start column (inclusive).
        y1: Start row (inclusive).
        x2: End column (exclusive).
        y2: End row (exclusive).
        channel_first: Selects which layout ``image`` uses.

    Returns:
        The sliced array, in the same layout as the input.
    """
    rows, cols = slice(y1, y2), slice(x1, x2)
    if not channel_first:
        # (H, W, C) format
        return image[rows, cols, :]
    # (C, H, W) format
    return image[:, rows, cols]
190
+
191
+
192
def resize_ndarray(
    image: NumpyArray,
    size: tuple[int, int],
    resample: int | Image.Resampling = Image.Resampling.LANCZOS,
    channel_first: bool = True,
) -> NumpyArray:
    """Resize a 3-channel image array by round-tripping through a PIL RGB image.

    Args:
        image: Array in (C, H, W) layout when ``channel_first`` is true,
            otherwise in (H, W, C) layout. float32/float64 input is assumed to
            be normalized to [0, 1]; anything else is cast to uint8.
        size: Target size, passed unchanged to ``Image.resize``.
        resample: Resampling filter passed to ``Image.resize``.
        channel_first: Selects which layout ``image`` uses.

    Returns:
        The resized array in the same layout as the input. float32/float64
        input yields float32 values in [0, 1]; other input yields uint8.
    """
    # Convert to PIL-friendly format (H, W, C)
    if channel_first:
        img_hwc = image.transpose((1, 2, 0))
    else:
        img_hwc = image

    # Handle different dtypes
    if img_hwc.dtype == np.float32 or img_hwc.dtype == np.float64:
        # Assume normalized, scale to 0-255 for PIL. Clip first: values even
        # slightly outside [0, 1] would otherwise fall outside the uint8 range
        # and the cast would wrap or produce an undefined result.
        img_hwc_scaled = np.clip(img_hwc * 255, 0, 255).astype(np.uint8)
        pil_img = Image.fromarray(img_hwc_scaled, mode="RGB")
        resized = pil_img.resize(size, resample)
        result = np.array(resized).astype(np.float32) / 255.0
    else:
        # uint8 or similar
        pil_img = Image.fromarray(img_hwc.astype(np.uint8), mode="RGB")
        resized = pil_img.resize(size, resample)
        result = np.array(resized)

    # Convert back to original format
    if channel_first:
        result = result.transpose((2, 0, 1))

    return result