fastembed_bio-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. fastembed/__init__.py +24 -0
  2. fastembed/bio/__init__.py +3 -0
  3. fastembed/bio/protein_embedding.py +456 -0
  4. fastembed/common/__init__.py +3 -0
  5. fastembed/common/model_description.py +52 -0
  6. fastembed/common/model_management.py +471 -0
  7. fastembed/common/onnx_model.py +188 -0
  8. fastembed/common/preprocessor_utils.py +84 -0
  9. fastembed/common/types.py +27 -0
  10. fastembed/common/utils.py +69 -0
  11. fastembed/embedding.py +24 -0
  12. fastembed/image/__init__.py +3 -0
  13. fastembed/image/image_embedding.py +135 -0
  14. fastembed/image/image_embedding_base.py +55 -0
  15. fastembed/image/onnx_embedding.py +217 -0
  16. fastembed/image/onnx_image_model.py +156 -0
  17. fastembed/image/transform/functional.py +221 -0
  18. fastembed/image/transform/operators.py +499 -0
  19. fastembed/late_interaction/__init__.py +5 -0
  20. fastembed/late_interaction/colbert.py +301 -0
  21. fastembed/late_interaction/jina_colbert.py +58 -0
  22. fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
  23. fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
  24. fastembed/late_interaction/token_embeddings.py +83 -0
  25. fastembed/late_interaction_multimodal/__init__.py +5 -0
  26. fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
  27. fastembed/late_interaction_multimodal/colpali.py +327 -0
  28. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
  29. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
  30. fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
  31. fastembed/parallel_processor.py +253 -0
  32. fastembed/postprocess/__init__.py +3 -0
  33. fastembed/postprocess/muvera.py +362 -0
  34. fastembed/py.typed +1 -0
  35. fastembed/rerank/cross_encoder/__init__.py +3 -0
  36. fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
  37. fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
  38. fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
  39. fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
  40. fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
  41. fastembed/sparse/__init__.py +4 -0
  42. fastembed/sparse/bm25.py +359 -0
  43. fastembed/sparse/bm42.py +369 -0
  44. fastembed/sparse/minicoil.py +372 -0
  45. fastembed/sparse/sparse_embedding_base.py +90 -0
  46. fastembed/sparse/sparse_text_embedding.py +143 -0
  47. fastembed/sparse/splade_pp.py +196 -0
  48. fastembed/sparse/utils/minicoil_encoder.py +146 -0
  49. fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
  50. fastembed/sparse/utils/tokenizer.py +120 -0
  51. fastembed/sparse/utils/vocab_resolver.py +202 -0
  52. fastembed/text/__init__.py +3 -0
  53. fastembed/text/clip_embedding.py +56 -0
  54. fastembed/text/custom_text_embedding.py +97 -0
  55. fastembed/text/multitask_embedding.py +109 -0
  56. fastembed/text/onnx_embedding.py +353 -0
  57. fastembed/text/onnx_text_model.py +180 -0
  58. fastembed/text/pooled_embedding.py +136 -0
  59. fastembed/text/pooled_normalized_embedding.py +164 -0
  60. fastembed/text/text_embedding.py +228 -0
  61. fastembed/text/text_embedding_base.py +75 -0
  62. fastembed_bio-0.1.0.dist-info/METADATA +339 -0
  63. fastembed_bio-0.1.0.dist-info/RECORD +66 -0
  64. fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
  65. fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
  66. fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
fastembed/common/preprocessor_utils.py ADDED
@@ -0,0 +1,84 @@
+ import json
+ from typing import Any
+ from pathlib import Path
+
+ from tokenizers import AddedToken, Tokenizer
+
+ from fastembed.image.transform.operators import Compose
+
+
+ def load_special_tokens(model_dir: Path) -> dict[str, Any]:
+     tokens_map_path = model_dir / "special_tokens_map.json"
+     if not tokens_map_path.exists():
+         raise ValueError(f"Could not find special_tokens_map.json in {model_dir}")
+
+     with open(str(tokens_map_path)) as tokens_map_file:
+         tokens_map = json.load(tokens_map_file)
+
+     return tokens_map
+
+
+ def load_tokenizer(model_dir: Path) -> tuple[Tokenizer, dict[str, int]]:
+     config_path = model_dir / "config.json"
+     if not config_path.exists():
+         raise ValueError(f"Could not find config.json in {model_dir}")
+
+     tokenizer_path = model_dir / "tokenizer.json"
+     if not tokenizer_path.exists():
+         raise ValueError(f"Could not find tokenizer.json in {model_dir}")
+
+     tokenizer_config_path = model_dir / "tokenizer_config.json"
+     if not tokenizer_config_path.exists():
+         raise ValueError(f"Could not find tokenizer_config.json in {model_dir}")
+
+     with open(str(config_path)) as config_file:
+         config = json.load(config_file)
+
+     with open(str(tokenizer_config_path)) as tokenizer_config_file:
+         tokenizer_config = json.load(tokenizer_config_file)
+         assert "model_max_length" in tokenizer_config or "max_length" in tokenizer_config, (
+             "Models without model_max_length or max_length are not supported."
+         )
+         if "model_max_length" not in tokenizer_config:
+             max_context = tokenizer_config["max_length"]
+         elif "max_length" not in tokenizer_config:
+             max_context = tokenizer_config["model_max_length"]
+         else:
+             max_context = min(tokenizer_config["model_max_length"], tokenizer_config["max_length"])
+
+     tokens_map = load_special_tokens(model_dir)
+
+     tokenizer = Tokenizer.from_file(str(tokenizer_path))
+     tokenizer.enable_truncation(max_length=max_context)
+     if not tokenizer.padding:
+         tokenizer.enable_padding(
+             pad_id=config.get("pad_token_id", 0), pad_token=tokenizer_config["pad_token"]
+         )
+
+     for token in tokens_map.values():
+         if isinstance(token, str):
+             tokenizer.add_special_tokens([token])
+         elif isinstance(token, dict):
+             tokenizer.add_special_tokens([AddedToken(**token)])
+
+     special_token_to_id: dict[str, int] = {}
+
+     for token in tokens_map.values():
+         if isinstance(token, str):
+             special_token_to_id[token] = tokenizer.token_to_id(token)
+         elif isinstance(token, dict):
+             token_str = token.get("content", "")
+             special_token_to_id[token_str] = tokenizer.token_to_id(token_str)
+
+     return tokenizer, special_token_to_id
+
+
+ def load_preprocessor(model_dir: Path) -> Compose:
+     preprocessor_config_path = model_dir / "preprocessor_config.json"
+     if not preprocessor_config_path.exists():
+         raise ValueError(f"Could not find preprocessor_config.json in {model_dir}")
+
+     with open(str(preprocessor_config_path)) as preprocessor_config_file:
+         preprocessor_config = json.load(preprocessor_config_file)
+     transforms = Compose.from_config(preprocessor_config)
+     return transforms
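
A minimal usage sketch of the loaders above, assuming a hypothetical local directory `./local_model` that already contains the Hugging Face artifacts they look for (config.json, tokenizer.json, tokenizer_config.json, special_tokens_map.json, preprocessor_config.json):

```
from pathlib import Path

from fastembed.common.preprocessor_utils import load_preprocessor, load_tokenizer

model_dir = Path("./local_model")  # hypothetical path, not shipped with the package

# Tokenizer with truncation/padding configured from the JSON files in model_dir
tokenizer, special_tokens = load_tokenizer(model_dir)
print(tokenizer.encode("a photo of a cat").ids, special_tokens)

# Compose pipeline of image transforms built from preprocessor_config.json
preprocess = load_preprocessor(model_dir)
```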
fastembed/common/types.py ADDED
@@ -0,0 +1,27 @@
+ from enum import Enum
+ from pathlib import Path
+ from typing import Any, TypeAlias
+
+ import numpy as np
+ from numpy.typing import NDArray
+ from PIL import Image
+
+
+ class Device(str, Enum):
+     CPU = "cpu"
+     CUDA = "cuda"
+     AUTO = "auto"
+
+
+ PathInput: TypeAlias = str | Path
+ ImageInput: TypeAlias = PathInput | Image.Image
+
+ OnnxProvider: TypeAlias = str | tuple[str, dict[Any, Any]]
+ NumpyArray: TypeAlias = (
+     NDArray[np.float64]
+     | NDArray[np.float32]
+     | NDArray[np.float16]
+     | NDArray[np.int8]
+     | NDArray[np.int64]
+     | NDArray[np.int32]
+ )
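
These aliases are used throughout the API signatures below; a small illustrative sketch (the `describe` helper is hypothetical and not part of the package):

```
import numpy as np

from fastembed.common.types import Device, ImageInput, NumpyArray

def describe(image: ImageInput, vector: NumpyArray, device: Device = Device.AUTO) -> str:
    # ImageInput accepts a str path, a pathlib.Path, or a PIL.Image.Image
    return f"{image!r}: dim={vector.shape[-1]}, device={device.value}"

print(describe("cat.jpg", np.zeros((1, 512), dtype=np.float32)))
```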
fastembed/common/utils.py ADDED
@@ -0,0 +1,69 @@
+ import os
+ import sys
+ import re
+ import tempfile
+ import unicodedata
+ from pathlib import Path
+ from itertools import islice
+ from typing import Iterable, TypeVar
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from fastembed.common.types import NumpyArray
+
+ T = TypeVar("T")
+
+
+ def normalize(input_array: NumpyArray, p: int = 2, dim: int = 1, eps: float = 1e-12) -> NumpyArray:
+     # Calculate the Lp norm along the specified dimension
+     norm = np.linalg.norm(input_array, ord=p, axis=dim, keepdims=True)
+     norm = np.maximum(norm, eps)  # Avoid division by zero
+     normalized_array = input_array / norm
+     return normalized_array
+
+
+ def mean_pooling(input_array: NumpyArray, attention_mask: NDArray[np.int64]) -> NumpyArray:
+     input_mask_expanded = np.expand_dims(attention_mask, axis=-1).astype(np.int64)
+     input_mask_expanded = np.tile(input_mask_expanded, (1, 1, input_array.shape[-1]))
+     sum_embeddings = np.sum(input_array * input_mask_expanded, axis=1)
+     sum_mask = np.sum(input_mask_expanded, axis=1)
+     pooled_embeddings = sum_embeddings / np.maximum(sum_mask, 1e-9)
+     return pooled_embeddings
+
+
+ def iter_batch(iterable: Iterable[T], size: int) -> Iterable[list[T]]:
+     """
+     >>> list(iter_batch([1,2,3,4,5], 3))
+     [[1, 2, 3], [4, 5]]
+     """
+     source_iter = iter(iterable)
+     while source_iter:
+         b = list(islice(source_iter, size))
+         if len(b) == 0:
+             break
+         yield b
+
+
+ def define_cache_dir(cache_dir: str | None = None) -> Path:
+     """
+     Define the cache directory for fastembed
+     """
+     if cache_dir is None:
+         default_cache_dir = os.path.join(tempfile.gettempdir(), "fastembed_cache")
+         cache_path = Path(os.getenv("FASTEMBED_CACHE_PATH", default_cache_dir))
+     else:
+         cache_path = Path(cache_dir)
+     cache_path.mkdir(parents=True, exist_ok=True)
+
+     return cache_path
+
+
+ def get_all_punctuation() -> set[str]:
+     return set(
+         chr(i) for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P")
+     )
+
+
+ def remove_non_alphanumeric(text: str) -> str:
+     return re.sub(r"[^\w\s]", " ", text, flags=re.UNICODE)
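
A short sketch of the numeric helpers above on toy data (the values are chosen here purely for illustration):

```
import numpy as np

from fastembed.common.utils import iter_batch, mean_pooling, normalize

# One sequence of two token embeddings (dim 4); the second token is masked out.
token_embeddings = np.array([[[1.0, 0.0, 0.0, 0.0], [9.0, 9.0, 9.0, 9.0]]], dtype=np.float32)
attention_mask = np.array([[1, 0]], dtype=np.int64)

pooled = mean_pooling(token_embeddings, attention_mask)  # masked token ignored -> [[1., 0., 0., 0.]]
unit = normalize(pooled)                                 # L2-normalized rows

print(list(iter_batch(range(5), 3)))  # [[0, 1, 2], [3, 4]]
```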
fastembed/embedding.py ADDED
@@ -0,0 +1,24 @@
+ from typing import Any
+
+ from loguru import logger
+
+ from fastembed import TextEmbedding
+
+ logger.warning(
+     "DefaultEmbedding, FlagEmbedding, JinaEmbedding are deprecated. "
+     "Use from fastembed import TextEmbedding instead."
+ )
+
+ DefaultEmbedding = TextEmbedding
+ FlagEmbedding = TextEmbedding
+
+
+ class JinaEmbedding(TextEmbedding):
+     def __init__(
+         self,
+         model_name: str = "jinaai/jina-embeddings-v2-base-en",
+         cache_dir: str | None = None,
+         threads: int | None = None,
+         **kwargs: Any,
+     ):
+         super().__init__(model_name, cache_dir, threads, **kwargs)
fastembed/image/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from fastembed.image.image_embedding import ImageEmbedding
+
+ __all__ = ["ImageEmbedding"]
fastembed/image/image_embedding.py ADDED
@@ -0,0 +1,135 @@
+ from typing import Any, Iterable, Sequence, Type
+ from dataclasses import asdict
+
+ from fastembed.common.types import NumpyArray, Device
+ from fastembed.common import ImageInput, OnnxProvider
+ from fastembed.image.image_embedding_base import ImageEmbeddingBase
+ from fastembed.image.onnx_embedding import OnnxImageEmbedding
+ from fastembed.common.model_description import DenseModelDescription
+
+
+ class ImageEmbedding(ImageEmbeddingBase):
+     EMBEDDINGS_REGISTRY: list[Type[ImageEmbeddingBase]] = [OnnxImageEmbedding]
+
+     @classmethod
+     def list_supported_models(cls) -> list[dict[str, Any]]:
+         """
+         Lists the supported models.
+
+         Returns:
+             list[dict[str, Any]]: A list of dictionaries containing the model information.
+
+         Example:
+             ```
+             [
+                 {
+                     "model": "Qdrant/clip-ViT-B-32-vision",
+                     "dim": 512,
+                     "description": "CLIP vision encoder based on ViT-B/32",
+                     "license": "mit",
+                     "size_in_GB": 0.33,
+                     "sources": {
+                         "hf": "Qdrant/clip-ViT-B-32-vision",
+                     },
+                     "model_file": "model.onnx",
+                 }
+             ]
+             ```
+         """
+         return [asdict(model) for model in cls._list_supported_models()]
+
+     @classmethod
+     def _list_supported_models(cls) -> list[DenseModelDescription]:
+         result: list[DenseModelDescription] = []
+         for embedding in cls.EMBEDDINGS_REGISTRY:
+             result.extend(embedding._list_supported_models())
+         return result
+
+     def __init__(
+         self,
+         model_name: str,
+         cache_dir: str | None = None,
+         threads: int | None = None,
+         providers: Sequence[OnnxProvider] | None = None,
+         cuda: bool | Device = Device.AUTO,
+         device_ids: list[int] | None = None,
+         lazy_load: bool = False,
+         **kwargs: Any,
+     ):
+         super().__init__(model_name, cache_dir, threads, **kwargs)
+         for EMBEDDING_MODEL_TYPE in self.EMBEDDINGS_REGISTRY:
+             supported_models = EMBEDDING_MODEL_TYPE._list_supported_models()
+             if any(model_name.lower() == model.model.lower() for model in supported_models):
+                 self.model = EMBEDDING_MODEL_TYPE(
+                     model_name,
+                     cache_dir,
+                     threads=threads,
+                     providers=providers,
+                     cuda=cuda,
+                     device_ids=device_ids,
+                     lazy_load=lazy_load,
+                     **kwargs,
+                 )
+                 return
+
+         raise ValueError(
+             f"Model {model_name} is not supported in ImageEmbedding. "
+             "Please check the supported models using `ImageEmbedding.list_supported_models()`"
+         )
+
+     @property
+     def embedding_size(self) -> int:
+         """Get the embedding size of the current model"""
+         if self._embedding_size is None:
+             self._embedding_size = self.get_embedding_size(self.model_name)
+         return self._embedding_size
+
+     @classmethod
+     def get_embedding_size(cls, model_name: str) -> int:
+         """Get the embedding size of the passed model
+
+         Args:
+             model_name (str): The name of the model to get embedding size for.
+
+         Returns:
+             int: The size of the embedding.
+
+         Raises:
+             ValueError: If the model name is not found in the supported models.
+         """
+         descriptions = cls._list_supported_models()
+         embedding_size: int | None = None
+         for description in descriptions:
+             if description.model.lower() == model_name.lower():
+                 embedding_size = description.dim
+                 break
+         if embedding_size is None:
+             model_names = [description.model for description in descriptions]
+             raise ValueError(
+                 f"Embedding size for model {model_name} was None. "
+                 f"Available model names: {model_names}"
+             )
+         return embedding_size
+
+     def embed(
+         self,
+         images: ImageInput | Iterable[ImageInput],
+         batch_size: int = 16,
+         parallel: int | None = None,
+         **kwargs: Any,
+     ) -> Iterable[NumpyArray]:
+         """
+         Encode a list of images into list of embeddings.
+
+         Args:
+             images: Iterator of image paths or single image path to embed
+             batch_size: Batch size for encoding -- higher values will use more memory, but be faster
+             parallel:
+                 If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
+                 If 0, use all available cores.
+                 If None, don't use data-parallel processing, use default onnxruntime threading instead.
+
+         Returns:
+             List of embeddings, one per document
+         """
+         yield from self.model.embed(images, batch_size, parallel, **kwargs)
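
The facade above resolves a model name against the registry and delegates to the ONNX implementation. A hedged usage sketch follows, assuming the top-level `fastembed` package re-exports `ImageEmbedding` (as its `__init__.py` suggests) and that the example image files exist:

```
from fastembed import ImageEmbedding

# Model name comes from the supported-models registry (see onnx_embedding.py below);
# weights are fetched into the fastembed cache on first use.
model = ImageEmbedding("Qdrant/clip-ViT-B-32-vision")
print(ImageEmbedding.get_embedding_size("Qdrant/clip-ViT-B-32-vision"))  # 512

# embed() is a generator: one vector per input image.
image_paths = ["photo_1.jpg", "photo_2.jpg"]  # hypothetical files
vectors = list(model.embed(image_paths, batch_size=16))
print(len(vectors), vectors[0].shape)
```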
fastembed/image/image_embedding_base.py ADDED
@@ -0,0 +1,55 @@
+ from typing import Iterable, Any
+
+ from fastembed.common.model_description import DenseModelDescription
+ from fastembed.common.types import NumpyArray
+ from fastembed.common.model_management import ModelManagement
+ from fastembed.common.types import ImageInput
+
+
+ class ImageEmbeddingBase(ModelManagement[DenseModelDescription]):
+     def __init__(
+         self,
+         model_name: str,
+         cache_dir: str | None = None,
+         threads: int | None = None,
+         **kwargs: Any,
+     ):
+         self.model_name = model_name
+         self.cache_dir = cache_dir
+         self.threads = threads
+         self._local_files_only = kwargs.pop("local_files_only", False)
+         self._embedding_size: int | None = None
+
+     def embed(
+         self,
+         images: ImageInput | Iterable[ImageInput],
+         batch_size: int = 16,
+         parallel: int | None = None,
+         **kwargs: Any,
+     ) -> Iterable[NumpyArray]:
+         """
+         Embeds a list of images into a list of embeddings.
+
+         Args:
+             images: The list of image paths to preprocess and embed.
+             batch_size: Batch size for encoding
+             parallel:
+                 If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
+                 If 0, use all available cores.
+                 If None, don't use data-parallel processing, use default onnxruntime threading instead.
+             **kwargs: Additional keyword argument to pass to the embed method.
+
+         Yields:
+             Iterable[NdArray]: The embeddings.
+         """
+         raise NotImplementedError()
+
+     @classmethod
+     def get_embedding_size(cls, model_name: str) -> int:
+         """Returns embedding size of the chosen model."""
+         raise NotImplementedError("Subclasses must implement this method")
+
+     @property
+     def embedding_size(self) -> int:
+         """Returns embedding size for the current model"""
+         raise NotImplementedError("Subclasses must implement this method")
fastembed/image/onnx_embedding.py ADDED
@@ -0,0 +1,217 @@
+ from typing import Any, Iterable, Sequence, Type
+
+
+ from fastembed.common.types import NumpyArray, Device
+ from fastembed.common import ImageInput, OnnxProvider
+ from fastembed.common.onnx_model import OnnxOutputContext
+ from fastembed.common.utils import define_cache_dir, normalize
+ from fastembed.image.image_embedding_base import ImageEmbeddingBase
+ from fastembed.image.onnx_image_model import ImageEmbeddingWorker, OnnxImageModel
+
+ from fastembed.common.model_description import DenseModelDescription, ModelSource
+
+ supported_onnx_models: list[DenseModelDescription] = [
+     DenseModelDescription(
+         model="Qdrant/clip-ViT-B-32-vision",
+         dim=512,
+         description="Image embeddings, Multimodal (text&image), 2021 year",
+         license="mit",
+         size_in_GB=0.34,
+         sources=ModelSource(hf="Qdrant/clip-ViT-B-32-vision"),
+         model_file="model.onnx",
+     ),
+     DenseModelDescription(
+         model="Qdrant/resnet50-onnx",
+         dim=2048,
+         description="Image embeddings, Unimodal (image), 2016 year",
+         license="apache-2.0",
+         size_in_GB=0.1,
+         sources=ModelSource(hf="Qdrant/resnet50-onnx"),
+         model_file="model.onnx",
+     ),
+     DenseModelDescription(
+         model="Qdrant/Unicom-ViT-B-16",
+         dim=768,
+         description="Image embeddings (more detailed than Unicom-ViT-B-32), Multimodal (text&image), 2023 year",
+         license="apache-2.0",
+         size_in_GB=0.82,
+         sources=ModelSource(hf="Qdrant/Unicom-ViT-B-16"),
+         model_file="model.onnx",
+     ),
+     DenseModelDescription(
+         model="Qdrant/Unicom-ViT-B-32",
+         dim=512,
+         description="Image embeddings, Multimodal (text&image), 2023 year",
+         license="apache-2.0",
+         size_in_GB=0.48,
+         sources=ModelSource(hf="Qdrant/Unicom-ViT-B-32"),
+         model_file="model.onnx",
+     ),
+     DenseModelDescription(
+         model="jinaai/jina-clip-v1",
+         dim=768,
+         description="Image embeddings, Multimodal (text&image), 2024 year",
+         license="apache-2.0",
+         size_in_GB=0.34,
+         sources=ModelSource(hf="jinaai/jina-clip-v1"),
+         model_file="onnx/vision_model.onnx",
+     ),
+ ]
+
+
+ class OnnxImageEmbedding(ImageEmbeddingBase, OnnxImageModel[NumpyArray]):
+     def __init__(
+         self,
+         model_name: str,
+
+         cache_dir: str | None = None,
+         threads: int | None = None,
+         providers: Sequence[OnnxProvider] | None = None,
+         cuda: bool | Device = Device.AUTO,
+         device_ids: list[int] | None = None,
+         lazy_load: bool = False,
+         device_id: int | None = None,
+         specific_model_path: str | None = None,
+         **kwargs: Any,
+     ):
+         """
+         Args:
+             model_name (str): The name of the model to use.
+             cache_dir (str, optional): The path to the cache directory.
+                 Can be set using the `FASTEMBED_CACHE_PATH` env variable.
+                 Defaults to `fastembed_cache` in the system's temp directory.
+             threads (int, optional): The number of threads single onnxruntime session can use. Defaults to None.
+             providers (Optional[Sequence[OnnxProvider]], optional): The list of onnxruntime providers to use.
+                 Mutually exclusive with the `cuda` and `device_ids` arguments. Defaults to None.
+             cuda (Union[bool, Device], optional): Whether to use cuda for inference. Mutually exclusive with `providers`
+                 Defaults to Device.AUTO.
+             device_ids (Optional[list[int]], optional): The list of device ids to use for data parallel processing in
+                 workers. Should be used with `cuda` equals to `True`, `Device.AUTO` or `Device.CUDA`, mutually exclusive
+                 with `providers`. Defaults to None.
+             lazy_load (bool, optional): Whether to load the model during class initialization or on demand.
+                 Should be set to True when using multiple-gpu and parallel encoding. Defaults to False.
+             device_id (Optional[int], optional): The device id to use for loading the model in the worker process.
+             specific_model_path (Optional[str], optional): The specific path to the onnx model dir if it should be imported from somewhere else
+
+         Raises:
+             ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
+         """
+
+         super().__init__(model_name, cache_dir, threads, **kwargs)
+         self.providers = providers
+         self.lazy_load = lazy_load
+         self._extra_session_options = self._select_exposed_session_options(kwargs)
+
+         # List of device ids, that can be used for data parallel processing in workers
+         self.device_ids = device_ids
+         self.cuda = cuda
+
+         # This device_id will be used if we need to load model in current process
+         self.device_id: int | None = None
+         if device_id is not None:
+             self.device_id = device_id
+         elif self.device_ids is not None:
+             self.device_id = self.device_ids[0]
+
+         self.model_description = self._get_model_description(model_name)
+         self.cache_dir = str(define_cache_dir(cache_dir))
+         self._specific_model_path = specific_model_path
+         self._model_dir = self.download_model(
+             self.model_description,
+             self.cache_dir,
+             local_files_only=self._local_files_only,
+             specific_model_path=self._specific_model_path,
+         )
+
+         if not self.lazy_load:
+             self.load_onnx_model()
+
+     def load_onnx_model(self) -> None:
+         """
+         Load the onnx model.
+         """
+         self._load_onnx_model(
+             model_dir=self._model_dir,
+             model_file=self.model_description.model_file,
+             threads=self.threads,
+             providers=self.providers,
+             cuda=self.cuda,
+             device_id=self.device_id,
+             extra_session_options=self._extra_session_options,
+         )
+
+     @classmethod
+     def _list_supported_models(cls) -> list[DenseModelDescription]:
+         """
+         Lists the supported models.
+
+         Returns:
+             list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
+         """
+         return supported_onnx_models
+
+     def embed(
+         self,
+         images: ImageInput | Iterable[ImageInput],
+         batch_size: int = 16,
+         parallel: int | None = None,
+         **kwargs: Any,
+     ) -> Iterable[NumpyArray]:
+         """
+         Encode a list of images into list of embeddings.
+         We use mean pooling with attention so that the model can handle variable-length inputs.
+
+         Args:
+             images: Iterator of image paths or single image path to embed
+             batch_size: Batch size for encoding -- higher values will use more memory, but be faster
+             parallel:
+                 If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
+                 If 0, use all available cores.
+                 If None, don't use data-parallel processing, use default onnxruntime threading instead.
+
+         Returns:
+             List of embeddings, one per document
+         """
+
+         yield from self._embed_images(
+             model_name=self.model_name,
+             cache_dir=str(self.cache_dir),
+             images=images,
+             batch_size=batch_size,
+             parallel=parallel,
+             providers=self.providers,
+             cuda=self.cuda,
+             device_ids=self.device_ids,
+             local_files_only=self._local_files_only,
+             specific_model_path=self._specific_model_path,
+             extra_session_options=self._extra_session_options,
+             **kwargs,
+         )
+
+     @classmethod
+     def _get_worker_class(cls) -> Type["ImageEmbeddingWorker[NumpyArray]"]:
+         return OnnxImageEmbeddingWorker
+
+     def _preprocess_onnx_input(
+         self, onnx_input: dict[str, NumpyArray], **kwargs: Any
+     ) -> dict[str, NumpyArray]:
+         """
+         Preprocess the onnx input.
+         """
+
+         return onnx_input
+
+     def _post_process_onnx_output(
+         self, output: OnnxOutputContext, **kwargs: Any
+     ) -> Iterable[NumpyArray]:
+         return normalize(output.model_output)
+
+
+ class OnnxImageEmbeddingWorker(ImageEmbeddingWorker[NumpyArray]):
+     def init_embedding(self, model_name: str, cache_dir: str, **kwargs: Any) -> OnnxImageEmbedding:
+         return OnnxImageEmbedding(
+             model_name=model_name,
+             cache_dir=cache_dir,
+             threads=1,
+             **kwargs,
+         )
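
The constructor docstring above describes the GPU and parallelism knobs; a hedged sketch of how they might be combined, assuming a CUDA-enabled onnxruntime build, two visible GPUs, and existing image files (none of which the package itself guarantees):

```
from fastembed.image.onnx_embedding import OnnxImageEmbedding

image_paths = ["img_0.jpg", "img_1.jpg"]  # hypothetical files

model = OnnxImageEmbedding(
    "Qdrant/clip-ViT-B-32-vision",
    cuda=True,           # mutually exclusive with `providers`
    device_ids=[0, 1],   # one worker per device when parallel encoding is used
    lazy_load=True,      # defer ONNX session creation until it is actually needed
)

# parallel=2 requests data-parallel workers; each worker loads the model on its own device.
embeddings = list(model.embed(image_paths, batch_size=64, parallel=2))
```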