fastembed-bio 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the package as it appears in its public registry.
- fastembed/__init__.py +24 -0
- fastembed/bio/__init__.py +3 -0
- fastembed/bio/protein_embedding.py +456 -0
- fastembed/common/__init__.py +3 -0
- fastembed/common/model_description.py +52 -0
- fastembed/common/model_management.py +471 -0
- fastembed/common/onnx_model.py +188 -0
- fastembed/common/preprocessor_utils.py +84 -0
- fastembed/common/types.py +27 -0
- fastembed/common/utils.py +69 -0
- fastembed/embedding.py +24 -0
- fastembed/image/__init__.py +3 -0
- fastembed/image/image_embedding.py +135 -0
- fastembed/image/image_embedding_base.py +55 -0
- fastembed/image/onnx_embedding.py +217 -0
- fastembed/image/onnx_image_model.py +156 -0
- fastembed/image/transform/functional.py +221 -0
- fastembed/image/transform/operators.py +499 -0
- fastembed/late_interaction/__init__.py +5 -0
- fastembed/late_interaction/colbert.py +301 -0
- fastembed/late_interaction/jina_colbert.py +58 -0
- fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
- fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
- fastembed/late_interaction/token_embeddings.py +83 -0
- fastembed/late_interaction_multimodal/__init__.py +5 -0
- fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
- fastembed/late_interaction_multimodal/colpali.py +327 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
- fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
- fastembed/parallel_processor.py +253 -0
- fastembed/postprocess/__init__.py +3 -0
- fastembed/postprocess/muvera.py +362 -0
- fastembed/py.typed +1 -0
- fastembed/rerank/cross_encoder/__init__.py +3 -0
- fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
- fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
- fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
- fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
- fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
- fastembed/sparse/__init__.py +4 -0
- fastembed/sparse/bm25.py +359 -0
- fastembed/sparse/bm42.py +369 -0
- fastembed/sparse/minicoil.py +372 -0
- fastembed/sparse/sparse_embedding_base.py +90 -0
- fastembed/sparse/sparse_text_embedding.py +143 -0
- fastembed/sparse/splade_pp.py +196 -0
- fastembed/sparse/utils/minicoil_encoder.py +146 -0
- fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
- fastembed/sparse/utils/tokenizer.py +120 -0
- fastembed/sparse/utils/vocab_resolver.py +202 -0
- fastembed/text/__init__.py +3 -0
- fastembed/text/clip_embedding.py +56 -0
- fastembed/text/custom_text_embedding.py +97 -0
- fastembed/text/multitask_embedding.py +109 -0
- fastembed/text/onnx_embedding.py +353 -0
- fastembed/text/onnx_text_model.py +180 -0
- fastembed/text/pooled_embedding.py +136 -0
- fastembed/text/pooled_normalized_embedding.py +164 -0
- fastembed/text/text_embedding.py +228 -0
- fastembed/text/text_embedding_base.py +75 -0
- fastembed_bio-0.1.0.dist-info/METADATA +339 -0
- fastembed_bio-0.1.0.dist-info/RECORD +66 -0
- fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
- fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
- fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
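
The layout mirrors upstream fastembed (text, image, sparse, late-interaction, and reranking modules), with a new `fastembed/bio` package for protein embeddings layered on top. For orientation only, a hedged quick-start is sketched below; it assumes the fork re-exports upstream's `TextEmbedding` entry point from `fastembed/__init__.py` and that upstream model names still resolve — neither is verified against this release.

```python
# Hedged quick-start sketch; the entry point and model name are assumptions
# carried over from upstream fastembed, not verified for fastembed-bio.
from fastembed import TextEmbedding

model = TextEmbedding(model_name="BAAI/bge-small-en-v1.5")
vectors = list(model.embed(["hello world"]))
print(len(vectors), len(vectors[0]))  # 1 document; 384-dim for bge-small upstream
```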
fastembed/image/onnx_image_model.py

@@ -0,0 +1,156 @@

```python
import contextlib
import os
from multiprocessing import get_all_start_methods
from pathlib import Path
from typing import Any, Iterable, Sequence, Type

import numpy as np
from PIL import Image

from fastembed.image.transform.operators import Compose
from fastembed.common.types import NumpyArray, Device
from fastembed.common import ImageInput, OnnxProvider
from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxOutputContext, T
from fastembed.common.preprocessor_utils import load_preprocessor
from fastembed.common.utils import iter_batch
from fastembed.parallel_processor import ParallelWorkerPool

# Holds type of the embedding result


class OnnxImageModel(OnnxModel[T]):
    @classmethod
    def _get_worker_class(cls) -> Type["ImageEmbeddingWorker[T]"]:
        raise NotImplementedError("Subclasses must implement this method")

    def _post_process_onnx_output(self, output: OnnxOutputContext, **kwargs: Any) -> Iterable[T]:
        """Post-process the ONNX model output to convert it into a usable format.

        Args:
            output (OnnxOutputContext): The raw output from the ONNX model.
            **kwargs: Additional keyword arguments that may be needed by specific implementations.

        Returns:
            Iterable[T]: Post-processed output as an iterable of type T.
        """
        raise NotImplementedError("Subclasses must implement this method")

    def __init__(self) -> None:
        super().__init__()
        self.processor: Compose | None = None

    def _preprocess_onnx_input(
        self, onnx_input: dict[str, NumpyArray], **kwargs: Any
    ) -> dict[str, NumpyArray]:
        """
        Preprocess the onnx input.
        """
        return onnx_input

    def _load_onnx_model(
        self,
        model_dir: Path,
        model_file: str,
        threads: int | None,
        providers: Sequence[OnnxProvider] | None = None,
        cuda: bool | Device = Device.AUTO,
        device_id: int | None = None,
        extra_session_options: dict[str, Any] | None = None,
    ) -> None:
        super()._load_onnx_model(
            model_dir=model_dir,
            model_file=model_file,
            threads=threads,
            providers=providers,
            cuda=cuda,
            device_id=device_id,
            extra_session_options=extra_session_options,
        )
        self.processor = load_preprocessor(model_dir=model_dir)

    def load_onnx_model(self) -> None:
        raise NotImplementedError("Subclasses must implement this method")

    def _build_onnx_input(self, encoded: NumpyArray) -> dict[str, NumpyArray]:
        input_name = self.model.get_inputs()[0].name  # type: ignore[union-attr]
        return {input_name: encoded}

    def onnx_embed(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext:
        with contextlib.ExitStack() as stack:
            image_files = [
                stack.enter_context(Image.open(image))
                if not isinstance(image, Image.Image)
                else image
                for image in images
            ]
            assert self.processor is not None, "Processor is not initialized"
            encoded = np.array(self.processor(image_files))
            onnx_input = self._build_onnx_input(encoded)
            onnx_input = self._preprocess_onnx_input(onnx_input)
            model_output = self.model.run(None, onnx_input)  # type: ignore[union-attr]
            embeddings = model_output[0].reshape(len(images), -1)
            return OnnxOutputContext(model_output=embeddings)

    def _embed_images(
        self,
        model_name: str,
        cache_dir: str,
        images: ImageInput | Iterable[ImageInput],
        batch_size: int = 256,
        parallel: int | None = None,
        providers: Sequence[OnnxProvider] | None = None,
        cuda: bool | Device = Device.AUTO,
        device_ids: list[int] | None = None,
        local_files_only: bool = False,
        specific_model_path: str | None = None,
        extra_session_options: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Iterable[T]:
        is_small = False

        if isinstance(images, (str, Path, Image.Image)):
            images = [images]
            is_small = True

        if isinstance(images, list) and len(images) < batch_size:
            is_small = True

        if parallel is None or is_small:
            if not hasattr(self, "model") or self.model is None:
                self.load_onnx_model()

            for batch in iter_batch(images, batch_size):
                yield from self._post_process_onnx_output(self.onnx_embed(batch), **kwargs)
        else:
            if parallel == 0:
                parallel = os.cpu_count()

            start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
            params = {
                "model_name": model_name,
                "cache_dir": cache_dir,
                "providers": providers,
                "local_files_only": local_files_only,
                "specific_model_path": specific_model_path,
                **kwargs,
            }

            if extra_session_options is not None:
                params.update(extra_session_options)

            pool = ParallelWorkerPool(
                num_workers=parallel or 1,
                worker=self._get_worker_class(),
                cuda=cuda,
                device_ids=device_ids,
                start_method=start_method,
            )
            for batch in pool.ordered_map(iter_batch(images, batch_size), **params):
                yield from self._post_process_onnx_output(batch, **kwargs)  # type: ignore


class ImageEmbeddingWorker(EmbeddingWorker[T]):
    def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
        for idx, batch in items:
            embeddings = self.model.onnx_embed(batch)
            yield idx, embeddings
```
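
`OnnxImageModel` is an abstract base: a concrete embedding class supplies `load_onnx_model`, `_post_process_onnx_output`, and (for the multiprocessing path) `_get_worker_class`, while the base class handles preprocessor loading, batching, and the serial-versus-parallel dispatch in `_embed_images`. A minimal sketch of that contract follows; `MyVisionModel`, `./model_dir`, `model.onnx`, and `photo.jpg` are hypothetical placeholders, not names from the package.

```python
# Sketch of the OnnxImageModel subclass contract; all concrete names here
# (MyVisionModel, ./model_dir, model.onnx, photo.jpg) are hypothetical.
from pathlib import Path
from typing import Any, Iterable

from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.types import NumpyArray
from fastembed.image.onnx_image_model import OnnxImageModel


class MyVisionModel(OnnxImageModel[NumpyArray]):
    def load_onnx_model(self) -> None:
        # Builds the ONNX session and initializes self.processor from model_dir
        self._load_onnx_model(
            model_dir=Path("./model_dir"), model_file="model.onnx", threads=None
        )

    def _post_process_onnx_output(
        self, output: OnnxOutputContext, **kwargs: Any
    ) -> Iterable[NumpyArray]:
        # onnx_embed has already reshaped the raw output to (batch, dim)
        yield from output.model_output


model = MyVisionModel()
model.load_onnx_model()
vectors = list(model._post_process_onnx_output(model.onnx_embed(["photo.jpg"])))
```

Parallel embedding through `_embed_images` would additionally need `_get_worker_class` to return an `ImageEmbeddingWorker` subclass, since each worker process reconstructs the model before calling `onnx_embed`.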
fastembed/image/transform/functional.py

@@ -0,0 +1,221 @@

```python
import numpy as np
from PIL import Image

from fastembed.common.types import NumpyArray


def convert_to_rgb(image: Image.Image) -> Image.Image:
    if image.mode == "RGB":
        return image

    image = image.convert("RGB")
    return image


def center_crop(
    image: Image.Image | NumpyArray,
    size: tuple[int, int],
) -> NumpyArray:
    if isinstance(image, np.ndarray):
        _, orig_height, orig_width = image.shape
    else:
        orig_height, orig_width = image.height, image.width
        # (H, W, C) -> (C, H, W)
        image = np.array(image).transpose((2, 0, 1))

    crop_height, crop_width = size

    # left upper corner (0, 0)
    top = (orig_height - crop_height) // 2
    bottom = top + crop_height
    left = (orig_width - crop_width) // 2
    right = left + crop_width

    # Check if cropped area is within image boundaries
    if top >= 0 and bottom <= orig_height and left >= 0 and right <= orig_width:
        image = image[..., top:bottom, left:right]
        return image

    # Padding with zeros
    new_height = max(crop_height, orig_height)
    new_width = max(crop_width, orig_width)
    new_shape = image.shape[:-2] + (new_height, new_width)
    new_image = np.zeros_like(image, shape=new_shape, dtype=np.float32)

    top_pad = (new_height - orig_height) // 2
    bottom_pad = top_pad + orig_height
    left_pad = (new_width - orig_width) // 2
    right_pad = left_pad + orig_width
    new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image

    top += top_pad
    bottom += top_pad
    left += left_pad
    right += left_pad

    new_image = new_image[
        ..., max(0, top) : min(new_height, bottom), max(0, left) : min(new_width, right)
    ]

    return new_image


def normalize(
    image: NumpyArray,
    mean: float | list[float],
    std: float | list[float],
) -> NumpyArray:
    num_channels = image.shape[1] if len(image.shape) == 4 else image.shape[0]

    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float32)

    mean_list = mean if isinstance(mean, list) else [mean] * num_channels

    if len(mean_list) != num_channels:
        raise ValueError(
            f"mean must have the same number of channels as the image, image has {num_channels} channels, got "
            f"{len(mean_list)}"
        )

    mean_arr = np.array(mean_list, dtype=np.float32)

    std_list = std if isinstance(std, list) else [std] * num_channels
    if len(std_list) != num_channels:
        raise ValueError(
            f"std must have the same number of channels as the image, image has {num_channels} channels, got {len(std_list)}"
        )

    std_arr = np.array(std_list, dtype=np.float32)

    image_upd = ((image.T - mean_arr) / std_arr).T
    return image_upd


def resize(
    image: Image.Image,
    size: int | tuple[int, int],
    resample: int | Image.Resampling = Image.Resampling.BILINEAR,
) -> Image.Image:
    if isinstance(size, tuple):
        return image.resize(size, resample)

    height, width = image.height, image.width
    short, long = (width, height) if width <= height else (height, width)

    new_short, new_long = size, int(size * long / short)
    if width <= height:
        new_size = (new_short, new_long)
    else:
        new_size = (new_long, new_short)
    return image.resize(new_size, resample)


def rescale(image: NumpyArray, scale: float, dtype: type = np.float32) -> NumpyArray:
    return (image * scale).astype(dtype)


def pil2ndarray(image: Image.Image | NumpyArray) -> NumpyArray:
    if isinstance(image, Image.Image):
        return np.asarray(image).transpose((2, 0, 1))
    return image


def pad2square(
    image: Image.Image,
    size: int,
    fill_color: str | int | tuple[int, ...] = 0,
) -> Image.Image:
    height, width = image.height, image.width

    left, right = 0, width
    top, bottom = 0, height

    crop_required = False
    if width > size:
        left = (width - size) // 2
        right = left + size
        crop_required = True

    if height > size:
        top = (height - size) // 2
        bottom = top + size
        crop_required = True

    new_image = Image.new(mode="RGB", size=(size, size), color=fill_color)
    new_image.paste(image.crop((left, top, right, bottom)) if crop_required else image)
    return new_image


def resize_longest_edge(
    image: Image.Image,
    max_size: int,
    resample: int | Image.Resampling = Image.Resampling.LANCZOS,
) -> Image.Image:
    height, width = image.height, image.width
    aspect_ratio = width / height

    if width >= height:
        # Width is longer
        new_width = max_size
        new_height = int(new_width / aspect_ratio)
    else:
        # Height is longer
        new_height = max_size
        new_width = int(new_height * aspect_ratio)

    # Ensure even dimensions
    if new_height % 2 != 0:
        new_height += 1
    if new_width % 2 != 0:
        new_width += 1

    return image.resize((new_width, new_height), resample)


def crop_ndarray(
    image: NumpyArray,
    x1: int,
    y1: int,
    x2: int,
    y2: int,
    channel_first: bool = True,
) -> NumpyArray:
    if channel_first:
        # (C, H, W) format
        return image[:, y1:y2, x1:x2]
    else:
        # (H, W, C) format
        return image[y1:y2, x1:x2, :]


def resize_ndarray(
    image: NumpyArray,
    size: tuple[int, int],
    resample: int | Image.Resampling = Image.Resampling.LANCZOS,
    channel_first: bool = True,
) -> NumpyArray:
    # Convert to PIL-friendly format (H, W, C)
    if channel_first:
        img_hwc = image.transpose((1, 2, 0))
    else:
        img_hwc = image

    # Handle different dtypes
    if img_hwc.dtype == np.float32 or img_hwc.dtype == np.float64:
        # Assume normalized, scale to 0-255 for PIL
        img_hwc_scaled = (img_hwc * 255).astype(np.uint8)
        pil_img = Image.fromarray(img_hwc_scaled, mode="RGB")
        resized = pil_img.resize(size, resample)
        result = np.array(resized).astype(np.float32) / 255.0
    else:
        # uint8 or similar
        pil_img = Image.fromarray(img_hwc.astype(np.uint8), mode="RGB")
        resized = pil_img.resize(size, resample)
        result = np.array(resized)

    # Convert back to original format
    if channel_first:
        result = result.transpose((2, 0, 1))

    return result
```
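
Together these helpers form the usual CLIP-style preprocessing chain: RGB conversion, aspect-preserving resize, center crop, rescale to [0, 1], and per-channel normalization. A hand-rolled sketch is shown below; in the package this composition is normally driven by `Compose` from `fastembed/image/transform/operators.py`, and the 224x224 target plus the OpenAI CLIP mean/std constants are illustrative choices, not values read from this release.

```python
# Illustrative composition of the helpers above; sizes and normalization
# constants are assumptions (standard CLIP values), not taken from the package.
import numpy as np
from PIL import Image

from fastembed.image.transform.functional import (
    center_crop,
    convert_to_rgb,
    normalize,
    rescale,
    resize,
)

image = Image.new("RGB", (640, 480))    # stand-in for Image.open(path)
image = convert_to_rgb(image)           # no-op here; converts other modes
image = resize(image, 224)              # shortest edge -> 224, aspect kept
array = center_crop(image, (224, 224))  # PIL -> (C, H, W) ndarray, center 224x224
array = rescale(array, 1 / 255.0)       # uint8 [0, 255] -> float32 [0, 1]
array = normalize(
    array,
    mean=[0.48145466, 0.4578275, 0.40821073],
    std=[0.26862954, 0.26130258, 0.27577711],
)
batch = np.expand_dims(array, 0)        # (1, 3, 224, 224), ready for the ONNX session
```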