fastembed-bio 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastembed/__init__.py +24 -0
- fastembed/bio/__init__.py +3 -0
- fastembed/bio/protein_embedding.py +456 -0
- fastembed/common/__init__.py +3 -0
- fastembed/common/model_description.py +52 -0
- fastembed/common/model_management.py +471 -0
- fastembed/common/onnx_model.py +188 -0
- fastembed/common/preprocessor_utils.py +84 -0
- fastembed/common/types.py +27 -0
- fastembed/common/utils.py +69 -0
- fastembed/embedding.py +24 -0
- fastembed/image/__init__.py +3 -0
- fastembed/image/image_embedding.py +135 -0
- fastembed/image/image_embedding_base.py +55 -0
- fastembed/image/onnx_embedding.py +217 -0
- fastembed/image/onnx_image_model.py +156 -0
- fastembed/image/transform/functional.py +221 -0
- fastembed/image/transform/operators.py +499 -0
- fastembed/late_interaction/__init__.py +5 -0
- fastembed/late_interaction/colbert.py +301 -0
- fastembed/late_interaction/jina_colbert.py +58 -0
- fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
- fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
- fastembed/late_interaction/token_embeddings.py +83 -0
- fastembed/late_interaction_multimodal/__init__.py +5 -0
- fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
- fastembed/late_interaction_multimodal/colpali.py +327 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
- fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
- fastembed/parallel_processor.py +253 -0
- fastembed/postprocess/__init__.py +3 -0
- fastembed/postprocess/muvera.py +362 -0
- fastembed/py.typed +1 -0
- fastembed/rerank/cross_encoder/__init__.py +3 -0
- fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
- fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
- fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
- fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
- fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
- fastembed/sparse/__init__.py +4 -0
- fastembed/sparse/bm25.py +359 -0
- fastembed/sparse/bm42.py +369 -0
- fastembed/sparse/minicoil.py +372 -0
- fastembed/sparse/sparse_embedding_base.py +90 -0
- fastembed/sparse/sparse_text_embedding.py +143 -0
- fastembed/sparse/splade_pp.py +196 -0
- fastembed/sparse/utils/minicoil_encoder.py +146 -0
- fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
- fastembed/sparse/utils/tokenizer.py +120 -0
- fastembed/sparse/utils/vocab_resolver.py +202 -0
- fastembed/text/__init__.py +3 -0
- fastembed/text/clip_embedding.py +56 -0
- fastembed/text/custom_text_embedding.py +97 -0
- fastembed/text/multitask_embedding.py +109 -0
- fastembed/text/onnx_embedding.py +353 -0
- fastembed/text/onnx_text_model.py +180 -0
- fastembed/text/pooled_embedding.py +136 -0
- fastembed/text/pooled_normalized_embedding.py +164 -0
- fastembed/text/text_embedding.py +228 -0
- fastembed/text/text_embedding_base.py +75 -0
- fastembed_bio-0.1.0.dist-info/METADATA +339 -0
- fastembed_bio-0.1.0.dist-info/RECORD +66 -0
- fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
- fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
- fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
fastembed/common/preprocessor_utils.py
ADDED
@@ -0,0 +1,84 @@
import json
from typing import Any
from pathlib import Path

from tokenizers import AddedToken, Tokenizer

from fastembed.image.transform.operators import Compose


def load_special_tokens(model_dir: Path) -> dict[str, Any]:
    tokens_map_path = model_dir / "special_tokens_map.json"
    if not tokens_map_path.exists():
        raise ValueError(f"Could not find special_tokens_map.json in {model_dir}")

    with open(str(tokens_map_path)) as tokens_map_file:
        tokens_map = json.load(tokens_map_file)

    return tokens_map


def load_tokenizer(model_dir: Path) -> tuple[Tokenizer, dict[str, int]]:
    config_path = model_dir / "config.json"
    if not config_path.exists():
        raise ValueError(f"Could not find config.json in {model_dir}")

    tokenizer_path = model_dir / "tokenizer.json"
    if not tokenizer_path.exists():
        raise ValueError(f"Could not find tokenizer.json in {model_dir}")

    tokenizer_config_path = model_dir / "tokenizer_config.json"
    if not tokenizer_config_path.exists():
        raise ValueError(f"Could not find tokenizer_config.json in {model_dir}")

    with open(str(config_path)) as config_file:
        config = json.load(config_file)

    with open(str(tokenizer_config_path)) as tokenizer_config_file:
        tokenizer_config = json.load(tokenizer_config_file)
        assert "model_max_length" in tokenizer_config or "max_length" in tokenizer_config, (
            "Models without model_max_length or max_length are not supported."
        )
        if "model_max_length" not in tokenizer_config:
            max_context = tokenizer_config["max_length"]
        elif "max_length" not in tokenizer_config:
            max_context = tokenizer_config["model_max_length"]
        else:
            max_context = min(tokenizer_config["model_max_length"], tokenizer_config["max_length"])

    tokens_map = load_special_tokens(model_dir)

    tokenizer = Tokenizer.from_file(str(tokenizer_path))
    tokenizer.enable_truncation(max_length=max_context)
    if not tokenizer.padding:
        tokenizer.enable_padding(
            pad_id=config.get("pad_token_id", 0), pad_token=tokenizer_config["pad_token"]
        )

    for token in tokens_map.values():
        if isinstance(token, str):
            tokenizer.add_special_tokens([token])
        elif isinstance(token, dict):
            tokenizer.add_special_tokens([AddedToken(**token)])

    special_token_to_id: dict[str, int] = {}

    for token in tokens_map.values():
        if isinstance(token, str):
            special_token_to_id[token] = tokenizer.token_to_id(token)
        elif isinstance(token, dict):
            token_str = token.get("content", "")
            special_token_to_id[token_str] = tokenizer.token_to_id(token_str)

    return tokenizer, special_token_to_id


def load_preprocessor(model_dir: Path) -> Compose:
    preprocessor_config_path = model_dir / "preprocessor_config.json"
    if not preprocessor_config_path.exists():
        raise ValueError(f"Could not find preprocessor_config.json in {model_dir}")

    with open(str(preprocessor_config_path)) as preprocessor_config_file:
        preprocessor_config = json.load(preprocessor_config_file)
    transforms = Compose.from_config(preprocessor_config)
    return transforms
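A minimal usage sketch for the helpers above (not part of the diff). It assumes a local model directory that already contains config.json, tokenizer.json, tokenizer_config.json, special_tokens_map.json and preprocessor_config.json; the path used here is hypothetical.

```
from pathlib import Path

from fastembed.common.preprocessor_utils import load_preprocessor, load_tokenizer

# Hypothetical directory with a downloaded model and its tokenizer/preprocessor configs.
model_dir = Path("/tmp/fastembed_cache/some-downloaded-model")

# The tokenizer comes back with truncation and padding configured from the JSON configs,
# alongside a mapping from special tokens to their ids.
tokenizer, special_token_to_id = load_tokenizer(model_dir)
print(tokenizer.encode("a photo of a cat").ids)
print(special_token_to_id)

# Image preprocessing pipeline assembled from preprocessor_config.json.
preprocessor = load_preprocessor(model_dir)
```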
fastembed/common/types.py
ADDED
@@ -0,0 +1,27 @@
from enum import Enum
from pathlib import Path
from typing import Any, TypeAlias

import numpy as np
from numpy.typing import NDArray
from PIL import Image


class Device(str, Enum):
    CPU = "cpu"
    CUDA = "cuda"
    AUTO = "auto"


PathInput: TypeAlias = str | Path
ImageInput: TypeAlias = PathInput | Image.Image

OnnxProvider: TypeAlias = str | tuple[str, dict[Any, Any]]
NumpyArray: TypeAlias = (
    NDArray[np.float64]
    | NDArray[np.float32]
    | NDArray[np.float16]
    | NDArray[np.int8]
    | NDArray[np.int64]
    | NDArray[np.int32]
)
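For illustration only (not in the diff): these aliases are intended for annotating user-facing signatures. The helper below is hypothetical and only shows how they read in practice.

```
from pathlib import Path

import numpy as np
from PIL import Image

from fastembed.common.types import Device, ImageInput, NumpyArray

# Device is a str-backed enum, so plain strings round-trip into it.
assert Device("cuda") is Device.CUDA


def to_float_array(image: ImageInput) -> NumpyArray:
    # Hypothetical helper: accept a file path or an in-memory PIL image.
    pil_image = image if isinstance(image, Image.Image) else Image.open(Path(image))
    return np.asarray(pil_image, dtype=np.float32)
```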
fastembed/common/utils.py
ADDED
@@ -0,0 +1,69 @@
import os
import sys
import re
import tempfile
import unicodedata
from pathlib import Path
from itertools import islice
from typing import Iterable, TypeVar

import numpy as np
from numpy.typing import NDArray

from fastembed.common.types import NumpyArray

T = TypeVar("T")


def normalize(input_array: NumpyArray, p: int = 2, dim: int = 1, eps: float = 1e-12) -> NumpyArray:
    # Calculate the Lp norm along the specified dimension
    norm = np.linalg.norm(input_array, ord=p, axis=dim, keepdims=True)
    norm = np.maximum(norm, eps)  # Avoid division by zero
    normalized_array = input_array / norm
    return normalized_array


def mean_pooling(input_array: NumpyArray, attention_mask: NDArray[np.int64]) -> NumpyArray:
    input_mask_expanded = np.expand_dims(attention_mask, axis=-1).astype(np.int64)
    input_mask_expanded = np.tile(input_mask_expanded, (1, 1, input_array.shape[-1]))
    sum_embeddings = np.sum(input_array * input_mask_expanded, axis=1)
    sum_mask = np.sum(input_mask_expanded, axis=1)
    pooled_embeddings = sum_embeddings / np.maximum(sum_mask, 1e-9)
    return pooled_embeddings


def iter_batch(iterable: Iterable[T], size: int) -> Iterable[list[T]]:
    """
    >>> list(iter_batch([1,2,3,4,5], 3))
    [[1, 2, 3], [4, 5]]
    """
    source_iter = iter(iterable)
    while source_iter:
        b = list(islice(source_iter, size))
        if len(b) == 0:
            break
        yield b


def define_cache_dir(cache_dir: str | None = None) -> Path:
    """
    Define the cache directory for fastembed
    """
    if cache_dir is None:
        default_cache_dir = os.path.join(tempfile.gettempdir(), "fastembed_cache")
        cache_path = Path(os.getenv("FASTEMBED_CACHE_PATH", default_cache_dir))
    else:
        cache_path = Path(cache_dir)
    cache_path.mkdir(parents=True, exist_ok=True)

    return cache_path


def get_all_punctuation() -> set[str]:
    return set(
        chr(i) for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P")
    )


def remove_non_alphanumeric(text: str) -> str:
    return re.sub(r"[^\w\s]", " ", text, flags=re.UNICODE)
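A small, self-contained sketch (not part of the package) of how the numeric helpers compose: masked mean pooling over token embeddings followed by L2 normalization, with iter_batch splitting a stream into fixed-size chunks. Shapes and values below are made up.

```
import numpy as np

from fastembed.common.utils import iter_batch, mean_pooling, normalize

# Fake token embeddings: batch of 2 sequences, 4 tokens each, 3 dimensions.
token_embeddings = np.random.rand(2, 4, 3).astype(np.float32)
# Attention mask marking the last token of the second sequence as padding.
attention_mask = np.array([[1, 1, 1, 1], [1, 1, 1, 0]], dtype=np.int64)

pooled = mean_pooling(token_embeddings, attention_mask)  # shape (2, 3)
unit_vectors = normalize(pooled)  # each row now has L2 norm ~1.0

print(np.linalg.norm(unit_vectors, axis=1))
print(list(iter_batch(range(5), 3)))  # [[0, 1, 2], [3, 4]]
```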
fastembed/embedding.py
ADDED
@@ -0,0 +1,24 @@
from typing import Any

from loguru import logger

from fastembed import TextEmbedding

logger.warning(
    "DefaultEmbedding, FlagEmbedding, JinaEmbedding are deprecated."
    "Use from fastembed import TextEmbedding instead."
)

DefaultEmbedding = TextEmbedding
FlagEmbedding = TextEmbedding


class JinaEmbedding(TextEmbedding):
    def __init__(
        self,
        model_name: str = "jinaai/jina-embeddings-v2-base-en",
        cache_dir: str | None = None,
        threads: int | None = None,
        **kwargs: Any,
    ):
        super().__init__(model_name, cache_dir, threads, **kwargs)
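For context (not part of the diff): the aliases above keep legacy imports working while logging a deprecation warning at import time; new code is expected to use TextEmbedding directly. The model name below is an assumption, check TextEmbedding.list_supported_models() for the actual list, and note that instantiation downloads model files.

```
from fastembed import TextEmbedding
from fastembed.embedding import DefaultEmbedding  # importing this module logs the warning

assert DefaultEmbedding is TextEmbedding

# Assumed model name; instantiation downloads the model into the fastembed cache.
model = TextEmbedding("BAAI/bge-small-en-v1.5")
vectors = list(model.embed(["hello world"]))
```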
fastembed/image/image_embedding.py
ADDED
@@ -0,0 +1,135 @@
from typing import Any, Iterable, Sequence, Type
from dataclasses import asdict

from fastembed.common.types import NumpyArray, Device
from fastembed.common import ImageInput, OnnxProvider
from fastembed.image.image_embedding_base import ImageEmbeddingBase
from fastembed.image.onnx_embedding import OnnxImageEmbedding
from fastembed.common.model_description import DenseModelDescription


class ImageEmbedding(ImageEmbeddingBase):
    EMBEDDINGS_REGISTRY: list[Type[ImageEmbeddingBase]] = [OnnxImageEmbedding]

    @classmethod
    def list_supported_models(cls) -> list[dict[str, Any]]:
        """
        Lists the supported models.

        Returns:
            list[dict[str, Any]]: A list of dictionaries containing the model information.

        Example:
            ```
            [
                {
                    "model": "Qdrant/clip-ViT-B-32-vision",
                    "dim": 512,
                    "description": "CLIP vision encoder based on ViT-B/32",
                    "license": "mit",
                    "size_in_GB": 0.33,
                    "sources": {
                        "hf": "Qdrant/clip-ViT-B-32-vision",
                    },
                    "model_file": "model.onnx",
                }
            ]
            ```
        """
        return [asdict(model) for model in cls._list_supported_models()]

    @classmethod
    def _list_supported_models(cls) -> list[DenseModelDescription]:
        result: list[DenseModelDescription] = []
        for embedding in cls.EMBEDDINGS_REGISTRY:
            result.extend(embedding._list_supported_models())
        return result

    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        threads: int | None = None,
        providers: Sequence[OnnxProvider] | None = None,
        cuda: bool | Device = Device.AUTO,
        device_ids: list[int] | None = None,
        lazy_load: bool = False,
        **kwargs: Any,
    ):
        super().__init__(model_name, cache_dir, threads, **kwargs)
        for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY:
            supported_models = EMBEDDING_MODEL_TYPE._list_supported_models()
            if any(model_name.lower() == model.model.lower() for model in supported_models):
                self.model = EMBEDDING_MODEL_TYPE(
                    model_name,
                    cache_dir,
                    threads=threads,
                    providers=providers,
                    cuda=cuda,
                    device_ids=device_ids,
                    lazy_load=lazy_load,
                    **kwargs,
                )
                return

        raise ValueError(
            f"Model {model_name} is not supported in ImageEmbedding."
            "Please check the supported models using `ImageEmbedding.list_supported_models()`"
        )

    @property
    def embedding_size(self) -> int:
        """Get the embedding size of the current model"""
        if self._embedding_size is None:
            self._embedding_size = self.get_embedding_size(self.model_name)
        return self._embedding_size

    @classmethod
    def get_embedding_size(cls, model_name: str) -> int:
        """Get the embedding size of the passed model

        Args:
            model_name (str): The name of the model to get embedding size for.

        Returns:
            int: The size of the embedding.

        Raises:
            ValueError: If the model name is not found in the supported models.
        """
        descriptions = cls._list_supported_models()
        embedding_size: int | None = None
        for description in descriptions:
            if description.model.lower() == model_name.lower():
                embedding_size = description.dim
                break
        if embedding_size is None:
            model_names = [description.model for description in descriptions]
            raise ValueError(
                f"Embedding size for model {model_name} was None. "
                f"Available model names: {model_names}"
            )
        return embedding_size

    def embed(
        self,
        images: ImageInput | Iterable[ImageInput],
        batch_size: int = 16,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> Iterable[NumpyArray]:
        """
        Encode a list of images into list of embeddings.

        Args:
            images: Iterator of image paths or single image path to embed
            batch_size: Batch size for encoding -- higher values will use more memory, but be faster
            parallel:
                If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
                If 0, use all available cores.
                If None, don't use data-parallel processing, use default onnxruntime threading instead.

        Returns:
            List of embeddings, one per document
        """
        yield from self.model.embed(images, batch_size, parallel, **kwargs)
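For context (not part of the diff): a minimal usage sketch of the ImageEmbedding facade. The image paths are placeholders, and constructing the class downloads the ONNX model into the cache directory.

```
from fastembed.image.image_embedding import ImageEmbedding

# Supported model names come straight from the registry shown above.
print([m["model"] for m in ImageEmbedding.list_supported_models()])

model = ImageEmbedding("Qdrant/clip-ViT-B-32-vision")  # downloads on first use
print(model.embedding_size)  # 512

# Placeholder paths; PIL.Image.Image instances are also accepted (see ImageInput).
embeddings = list(model.embed(["images/cat.jpg", "images/dog.jpg"], batch_size=8))
```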
fastembed/image/image_embedding_base.py
ADDED
@@ -0,0 +1,55 @@
from typing import Iterable, Any

from fastembed.common.model_description import DenseModelDescription
from fastembed.common.types import NumpyArray
from fastembed.common.model_management import ModelManagement
from fastembed.common.types import ImageInput


class ImageEmbeddingBase(ModelManagement[DenseModelDescription]):
    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        threads: int | None = None,
        **kwargs: Any,
    ):
        self.model_name = model_name
        self.cache_dir = cache_dir
        self.threads = threads
        self._local_files_only = kwargs.pop("local_files_only", False)
        self._embedding_size: int | None = None

    def embed(
        self,
        images: ImageInput | Iterable[ImageInput],
        batch_size: int = 16,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> Iterable[NumpyArray]:
        """
        Embeds a list of images into a list of embeddings.

        Args:
            images: The list of image paths to preprocess and embed.
            batch_size: Batch size for encoding
            parallel:
                If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
                If 0, use all available cores.
                If None, don't use data-parallel processing, use default onnxruntime threading instead.
            **kwargs: Additional keyword argument to pass to the embed method.

        Yields:
            Iterable[NdArray]: The embeddings.
        """
        raise NotImplementedError()

    @classmethod
    def get_embedding_size(cls, model_name: str) -> int:
        """Returns embedding size of the chosen model."""
        raise NotImplementedError("Subclasses must implement this method")

    @property
    def embedding_size(self) -> int:
        """Returns embedding size for the current model"""
        raise NotImplementedError("Subclasses must implement this method")
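To make the contract concrete (not part of the package): a hypothetical backend subclass sketch. A real backend would also expose its model descriptions through _list_supported_models; here only the embed contract and the size accessors are filled in, with random vectors standing in for model output.

```
from typing import Any, Iterable

import numpy as np

from fastembed.common.types import ImageInput, NumpyArray
from fastembed.image.image_embedding_base import ImageEmbeddingBase


class RandomImageEmbedding(ImageEmbeddingBase):
    """Hypothetical backend returning random vectors; illustrates the base contract only."""

    DIM = 8

    def embed(
        self,
        images: ImageInput | Iterable[ImageInput],
        batch_size: int = 16,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> Iterable[NumpyArray]:
        for _ in images:  # assumes an iterable of image paths, for brevity
            yield np.random.rand(self.DIM).astype(np.float32)

    @classmethod
    def get_embedding_size(cls, model_name: str) -> int:
        return cls.DIM

    @property
    def embedding_size(self) -> int:
        return self.DIM
```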
fastembed/image/onnx_embedding.py
ADDED
@@ -0,0 +1,217 @@
from typing import Any, Iterable, Sequence, Type


from fastembed.common.types import NumpyArray, Device
from fastembed.common import ImageInput, OnnxProvider
from fastembed.common.onnx_model import OnnxOutputContext
from fastembed.common.utils import define_cache_dir, normalize
from fastembed.image.image_embedding_base import ImageEmbeddingBase
from fastembed.image.onnx_image_model import ImageEmbeddingWorker, OnnxImageModel

from fastembed.common.model_description import DenseModelDescription, ModelSource

supported_onnx_models: list[DenseModelDescription] = [
    DenseModelDescription(
        model="Qdrant/clip-ViT-B-32-vision",
        dim=512,
        description="Image embeddings, Multimodal (text&image), 2021 year",
        license="mit",
        size_in_GB=0.34,
        sources=ModelSource(hf="Qdrant/clip-ViT-B-32-vision"),
        model_file="model.onnx",
    ),
    DenseModelDescription(
        model="Qdrant/resnet50-onnx",
        dim=2048,
        description="Image embeddings, Unimodal (image), 2016 year",
        license="apache-2.0",
        size_in_GB=0.1,
        sources=ModelSource(hf="Qdrant/resnet50-onnx"),
        model_file="model.onnx",
    ),
    DenseModelDescription(
        model="Qdrant/Unicom-ViT-B-16",
        dim=768,
        description="Image embeddings (more detailed than Unicom-ViT-B-32), Multimodal (text&image), 2023 year",
        license="apache-2.0",
        size_in_GB=0.82,
        sources=ModelSource(hf="Qdrant/Unicom-ViT-B-16"),
        model_file="model.onnx",
    ),
    DenseModelDescription(
        model="Qdrant/Unicom-ViT-B-32",
        dim=512,
        description="Image embeddings, Multimodal (text&image), 2023 year",
        license="apache-2.0",
        size_in_GB=0.48,
        sources=ModelSource(hf="Qdrant/Unicom-ViT-B-32"),
        model_file="model.onnx",
    ),
    DenseModelDescription(
        model="jinaai/jina-clip-v1",
        dim=768,
        description="Image embeddings, Multimodal (text&image), 2024 year",
        license="apache-2.0",
        size_in_GB=0.34,
        sources=ModelSource(hf="jinaai/jina-clip-v1"),
        model_file="onnx/vision_model.onnx",
    ),
]


class OnnxImageEmbedding(ImageEmbeddingBase, OnnxImageModel[NumpyArray]):
    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        threads: int | None = None,
        providers: Sequence[OnnxProvider] | None = None,
        cuda: bool | Device = Device.AUTO,
        device_ids: list[int] | None = None,
        lazy_load: bool = False,
        device_id: int | None = None,
        specific_model_path: str | None = None,
        **kwargs: Any,
    ):
        """
        Args:
            model_name (str): The name of the model to use.
            cache_dir (str, optional): The path to the cache directory.
                Can be set using the `FASTEMBED_CACHE_PATH` env variable.
                Defaults to `fastembed_cache` in the system's temp directory.
            threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
            providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
                Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
            cuda (Union[bool, Device], optional): Whether to use cuda for inference. Mutually exclusive with `providers`
                Defaults to Device.AUTO.
            device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
                workers. Should be used with `cuda` equals to `True`, `Device.AUTO` or `Device.CUDA`, mutually exclusive
                with `providers`. Defaults to None.
            lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
                Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
            device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
            specific_model_path (Optional[str], optional): The specific path to the onnx model dir if it should be imported from somewhere else

        Raises:
            ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
        """

        super().__init__(model_name, cache_dir, threads, **kwargs)
        self.providers = providers
        self.lazy_load = lazy_load
        self._extra_session_options = self._select_exposed_session_options(kwargs)

        # List of device ids, that can be used for data parallel processing in workers
        self.device_ids = device_ids
        self.cuda = cuda

        # This device_id will be used if we need to load model in current process
        self.device_id: int | None = None
        if device_id is not None:
            self.device_id = device_id
        elif self.device_ids is not None:
            self.device_id = self.device_ids[0]

        self.model_description = self._get_model_description(model_name)
        self.cache_dir = str(define_cache_dir(cache_dir))
        self._specific_model_path = specific_model_path
        self._model_dir = self.download_model(
            self.model_description,
            self.cache_dir,
            local_files_only=self._local_files_only,
            specific_model_path=self._specific_model_path,
        )

        if not self.lazy_load:
            self.load_onnx_model()

    def load_onnx_model(self) -> None:
        """
        Load the onnx model.
        """
        self._load_onnx_model(
            model_dir=self._model_dir,
            model_file=self.model_description.model_file,
            threads=self.threads,
            providers=self.providers,
            cuda=self.cuda,
            device_id=self.device_id,
            extra_session_options=self._extra_session_options,
        )

    @classmethod
    def _list_supported_models(cls) -> list[DenseModelDescription]:
        """
        Lists the supported models.

        Returns:
            list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
        """
        return supported_onnx_models

    def embed(
        self,
        images: ImageInput | Iterable[ImageInput],
        batch_size: int = 16,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> Iterable[NumpyArray]:
        """
        Encode a list of images into list of embeddings.
        We use mean pooling with attention so that the model can handle variable-length inputs.

        Args:
            images: Iterator of image paths or single image path to embed
            batch_size: Batch size for encoding -- higher values will use more memory, but be faster
            parallel:
                If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
                If 0, use all available cores.
                If None, don't use data-parallel processing, use default onnxruntime threading instead.

        Returns:
            List of embeddings, one per document
        """

        yield from self._embed_images(
            model_name=self.model_name,
            cache_dir=str(self.cache_dir),
            images=images,
            batch_size=batch_size,
            parallel=parallel,
            providers=self.providers,
            cuda=self.cuda,
            device_ids=self.device_ids,
            local_files_only=self._local_files_only,
            specific_model_path=self._specific_model_path,
            extra_session_options=self._extra_session_options,
            **kwargs,
        )

    @classmethod
    def _get_worker_class(cls) -> Type["ImageEmbeddingWorker[NumpyArray]"]:
        return OnnxImageEmbeddingWorker

    def _preprocess_onnx_input(
        self, onnx_input: dict[str, NumpyArray], **kwargs: Any
    ) -> dict[str, NumpyArray]:
        """
        Preprocess the onnx input.
        """

        return onnx_input

    def _post_process_onnx_output(
        self, output: OnnxOutputContext, **kwargs: Any
    ) -> Iterable[NumpyArray]:
        return normalize(output.model_output)


class OnnxImageEmbeddingWorker(ImageEmbeddingWorker[NumpyArray]):
    def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> OnnxImageEmbedding:
        return OnnxImageEmbedding(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=1,
            **kwargs,
        )
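For context (not part of the diff): the concrete ONNX backend can also be used directly when execution providers or GPU placement need to be controlled. Provider names are standard onnxruntime identifiers and depend on the installed onnxruntime build; the image path is a placeholder.

```
from fastembed.common.types import Device
from fastembed.image.onnx_embedding import OnnxImageEmbedding

# CPU-only session; the model is downloaded into the default fastembed cache.
cpu_model = OnnxImageEmbedding("Qdrant/resnet50-onnx", providers=["CPUExecutionProvider"])

# GPU placement via `cuda` / `device_ids` instead of explicit providers; lazy_load defers
# session creation until the first embed() call, which matters for parallel workers.
gpu_model = OnnxImageEmbedding(
    "Qdrant/clip-ViT-B-32-vision",
    cuda=Device.CUDA,
    device_ids=[0],
    lazy_load=True,
)

vectors = list(cpu_model.embed(["images/example.jpg"]))  # placeholder path
```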