fastembed-bio 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastembed/__init__.py +24 -0
- fastembed/bio/__init__.py +3 -0
- fastembed/bio/protein_embedding.py +456 -0
- fastembed/common/__init__.py +3 -0
- fastembed/common/model_description.py +52 -0
- fastembed/common/model_management.py +471 -0
- fastembed/common/onnx_model.py +188 -0
- fastembed/common/preprocessor_utils.py +84 -0
- fastembed/common/types.py +27 -0
- fastembed/common/utils.py +69 -0
- fastembed/embedding.py +24 -0
- fastembed/image/__init__.py +3 -0
- fastembed/image/image_embedding.py +135 -0
- fastembed/image/image_embedding_base.py +55 -0
- fastembed/image/onnx_embedding.py +217 -0
- fastembed/image/onnx_image_model.py +156 -0
- fastembed/image/transform/functional.py +221 -0
- fastembed/image/transform/operators.py +499 -0
- fastembed/late_interaction/__init__.py +5 -0
- fastembed/late_interaction/colbert.py +301 -0
- fastembed/late_interaction/jina_colbert.py +58 -0
- fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
- fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
- fastembed/late_interaction/token_embeddings.py +83 -0
- fastembed/late_interaction_multimodal/__init__.py +5 -0
- fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
- fastembed/late_interaction_multimodal/colpali.py +327 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
- fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
- fastembed/parallel_processor.py +253 -0
- fastembed/postprocess/__init__.py +3 -0
- fastembed/postprocess/muvera.py +362 -0
- fastembed/py.typed +1 -0
- fastembed/rerank/cross_encoder/__init__.py +3 -0
- fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
- fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
- fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
- fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
- fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
- fastembed/sparse/__init__.py +4 -0
- fastembed/sparse/bm25.py +359 -0
- fastembed/sparse/bm42.py +369 -0
- fastembed/sparse/minicoil.py +372 -0
- fastembed/sparse/sparse_embedding_base.py +90 -0
- fastembed/sparse/sparse_text_embedding.py +143 -0
- fastembed/sparse/splade_pp.py +196 -0
- fastembed/sparse/utils/minicoil_encoder.py +146 -0
- fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
- fastembed/sparse/utils/tokenizer.py +120 -0
- fastembed/sparse/utils/vocab_resolver.py +202 -0
- fastembed/text/__init__.py +3 -0
- fastembed/text/clip_embedding.py +56 -0
- fastembed/text/custom_text_embedding.py +97 -0
- fastembed/text/multitask_embedding.py +109 -0
- fastembed/text/onnx_embedding.py +353 -0
- fastembed/text/onnx_text_model.py +180 -0
- fastembed/text/pooled_embedding.py +136 -0
- fastembed/text/pooled_normalized_embedding.py +164 -0
- fastembed/text/text_embedding.py +228 -0
- fastembed/text/text_embedding_base.py +75 -0
- fastembed_bio-0.1.0.dist-info/METADATA +339 -0
- fastembed_bio-0.1.0.dist-info/RECORD +66 -0
- fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
- fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
- fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
from typing import Any, Iterable, Type
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
from fastembed.common.types import NumpyArray
|
|
5
|
+
from fastembed.common.onnx_model import OnnxOutputContext
|
|
6
|
+
from fastembed.common.utils import normalize
|
|
7
|
+
from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker
|
|
8
|
+
from fastembed.text.pooled_embedding import PooledEmbedding
|
|
9
|
+
from fastembed.common.model_description import DenseModelDescription, ModelSource
|
|
10
|
+
|
|
11
|
+
# Registry of dense text models served by the pooled + normalized embedding class.
# Each entry records the public model name, the output dimensionality, where the
# ONNX weights are fetched from, and the path of the ONNX file inside that source.
supported_pooled_normalized_models: list[DenseModelDescription] = [
    # --- sentence-transformers ---
    DenseModelDescription(
        model="sentence-transformers/all-MiniLM-L6-v2",
        dim=384,
        sources=ModelSource(
            hf="qdrant/all-MiniLM-L6-v2-onnx",
            url="https://storage.googleapis.com/qdrant-fastembed/sentence-transformers-all-MiniLM-L6-v2.tar.gz",
            # NOTE(review): presumably marks the legacy tarball layout -- confirm in ModelSource.
            _deprecated_tar_struct=True,
        ),
        model_file="model.onnx",
        license="apache-2.0",
        size_in_GB=0.09,
        description=(
            "Text embeddings, Unimodal (text), English, 256 input tokens truncation, "
            "Prefixes for queries/documents: not necessary, 2021 year."
        ),
    ),
    # --- jina-embeddings-v2 family ---
    DenseModelDescription(
        model="jinaai/jina-embeddings-v2-base-en",
        dim=768,
        sources=ModelSource(hf="xenova/jina-embeddings-v2-base-en"),
        model_file="onnx/model.onnx",
        license="apache-2.0",
        size_in_GB=0.52,
        description=(
            "Text embeddings, Unimodal (text), English, 8192 input tokens truncation, "
            "Prefixes for queries/documents: not necessary, 2023 year."
        ),
    ),
    DenseModelDescription(
        model="jinaai/jina-embeddings-v2-small-en",
        dim=512,
        sources=ModelSource(hf="xenova/jina-embeddings-v2-small-en"),
        model_file="onnx/model.onnx",
        license="apache-2.0",
        size_in_GB=0.12,
        description=(
            "Text embeddings, Unimodal (text), English, 8192 input tokens truncation, "
            "Prefixes for queries/documents: not necessary, 2023 year."
        ),
    ),
    DenseModelDescription(
        model="jinaai/jina-embeddings-v2-base-de",
        dim=768,
        sources=ModelSource(hf="jinaai/jina-embeddings-v2-base-de"),
        # This entry points at the fp16 export rather than the default model file.
        model_file="onnx/model_fp16.onnx",
        license="apache-2.0",
        size_in_GB=0.32,
        description=(
            "Text embeddings, Unimodal (text), Multilingual (German, English), 8192 input tokens truncation, "
            "Prefixes for queries/documents: not necessary, 2024 year."
        ),
    ),
    DenseModelDescription(
        model="jinaai/jina-embeddings-v2-base-code",
        dim=768,
        sources=ModelSource(hf="jinaai/jina-embeddings-v2-base-code"),
        model_file="onnx/model.onnx",
        license="apache-2.0",
        size_in_GB=0.64,
        description=(
            "Text embeddings, Unimodal (text), Multilingual (English, 30 programming languages), "
            "8192 input tokens truncation, Prefixes for queries/documents: not necessary, 2024 year."
        ),
    ),
    DenseModelDescription(
        model="jinaai/jina-embeddings-v2-base-zh",
        dim=768,
        sources=ModelSource(hf="jinaai/jina-embeddings-v2-base-zh"),
        model_file="onnx/model.onnx",
        license="apache-2.0",
        size_in_GB=0.64,
        description=(
            "Text embeddings, Unimodal (text), supports mixed Chinese-English input text, "
            "8192 input tokens truncation, Prefixes for queries/documents: not necessary, 2024 year."
        ),
    ),
    DenseModelDescription(
        model="jinaai/jina-embeddings-v2-base-es",
        dim=768,
        sources=ModelSource(hf="jinaai/jina-embeddings-v2-base-es"),
        model_file="onnx/model.onnx",
        license="apache-2.0",
        size_in_GB=0.64,
        description=(
            "Text embeddings, Unimodal (text), supports mixed Spanish-English input text, "
            "8192 input tokens truncation, Prefixes for queries/documents: not necessary, 2024 year."
        ),
    ),
    # --- thenlper/gte family ---
    DenseModelDescription(
        model="thenlper/gte-base",
        dim=768,
        sources=ModelSource(hf="thenlper/gte-base"),
        model_file="onnx/model.onnx",
        license="mit",
        size_in_GB=0.44,
        description=(
            "General text embeddings, Unimodal (text), supports English only input text, "
            "512 input tokens truncation, Prefixes for queries/documents: not necessary, 2024 year."
        ),
    ),
    DenseModelDescription(
        model="thenlper/gte-large",
        dim=1024,
        sources=ModelSource(hf="qdrant/gte-large-onnx"),
        model_file="model.onnx",
        license="mit",
        size_in_GB=1.20,
        description=(
            "Text embeddings, Unimodal (text), English, 512 input tokens truncation, "
            "Prefixes for queries/documents: not necessary, 2023 year."
        ),
    ),
]
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
class PooledNormalizedEmbedding(PooledEmbedding):
    """Pooled text embedding whose output is additionally passed through ``normalize``.

    Serves the models listed in ``supported_pooled_normalized_models`` and uses
    ``PooledNormalizedEmbeddingWorker`` as its worker class.
    """

    @classmethod
    def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
        """Return the worker class paired with this embedding implementation."""
        return PooledNormalizedEmbeddingWorker

    @classmethod
    def _list_supported_models(cls) -> list[DenseModelDescription]:
        """Return the descriptions of the models this class can serve.

        Returns:
            list[DenseModelDescription]: The module-level registry itself (not a copy).
        """
        return supported_pooled_normalized_models

    def _post_process_onnx_output(
        self, output: OnnxOutputContext, **kwargs: Any
    ) -> Iterable[NumpyArray]:
        """Pool the raw model output with the attention mask, then normalize it.

        Args:
            output: ONNX output context; ``model_output`` and ``attention_mask`` are read.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            Iterable[NumpyArray]: Result of ``normalize(self.mean_pooling(...))``.

        Raises:
            ValueError: If ``output.attention_mask`` is ``None``.
        """
        mask = output.attention_mask
        if mask is None:
            raise ValueError("attention_mask must be provided for document post-processing")

        # Pooling comes from the parent class; normalize is presumably an L2 norm -- confirm in utils.
        pooled = self.mean_pooling(output.model_output, mask)
        return normalize(pooled)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
class PooledNormalizedEmbeddingWorker(OnnxTextEmbeddingWorker):
    """Worker that builds ``PooledNormalizedEmbedding`` instances."""

    def init_embedding(
        self,
        model_name: str,
        cache_dir: str,
        **kwargs: Any,
    ) -> OnnxTextEmbedding:
        """Create the embedding model owned by this worker.

        Args:
            model_name: Name of the model to load.
            cache_dir: Directory handed to the model as its cache location.
            **kwargs: Forwarded unchanged to ``PooledNormalizedEmbedding``.

        Returns:
            OnnxTextEmbedding: A ``PooledNormalizedEmbedding`` created with ``threads=1``.
        """
        # threads is pinned to 1 for every worker-owned model instance.
        worker_model = PooledNormalizedEmbedding(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=1,
            **kwargs,
        )
        return worker_model
|
|
@@ -0,0 +1,228 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
from typing import Any, Iterable, Sequence, Type
|
|
3
|
+
from dataclasses import asdict
|
|
4
|
+
|
|
5
|
+
from fastembed.common.types import NumpyArray, OnnxProvider, Device
|
|
6
|
+
from fastembed.text.clip_embedding import CLIPOnnxEmbedding
|
|
7
|
+
from fastembed.text.custom_text_embedding import CustomTextEmbedding
|
|
8
|
+
from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
|
|
9
|
+
from fastembed.text.pooled_embedding import PooledEmbedding
|
|
10
|
+
from fastembed.text.multitask_embedding import JinaEmbeddingV3
|
|
11
|
+
from fastembed.text.onnx_embedding import OnnxTextEmbedding
|
|
12
|
+
from fastembed.text.text_embedding_base import TextEmbeddingBase
|
|
13
|
+
from fastembed.common.model_description import DenseModelDescription, ModelSource, PoolingType
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TextEmbedding(TextEmbeddingBase):
    """Front-end for dense text embeddings that dispatches to a concrete implementation.

    On construction the classes in ``EMBEDDINGS_REGISTRY`` are scanned in order and the
    first one whose supported-model list contains ``model_name`` (case-insensitive) is
    instantiated as ``self.model``. All embedding and token-counting calls are then
    delegated to that instance.
    """

    # Order matters: the first implementation that lists the requested model wins.
    EMBEDDINGS_REGISTRY: list[Type[TextEmbeddingBase]] = [
        OnnxTextEmbedding,
        CLIPOnnxEmbedding,
        PooledNormalizedEmbedding,
        PooledEmbedding,
        JinaEmbeddingV3,
        CustomTextEmbedding,
    ]

    @classmethod
    def list_supported_models(cls) -> list[dict[str, Any]]:
        """Lists the supported models.

        Returns:
            list[dict[str, Any]]: A list of dictionaries containing the model information.
        """
        return [asdict(model) for model in cls._list_supported_models()]

    @classmethod
    def _list_supported_models(cls) -> list[DenseModelDescription]:
        """Collect the model descriptions of every registered implementation.

        Returns:
            list[DenseModelDescription]: Descriptions concatenated in registry order.
        """
        result: list[DenseModelDescription] = []
        for embedding in cls.EMBEDDINGS_REGISTRY:
            result.extend(embedding._list_supported_models())
        return result

    @classmethod
    def add_custom_model(
        cls,
        model: str,
        pooling: PoolingType,
        normalization: bool,
        sources: ModelSource,
        dim: int,
        model_file: str = "onnx/model.onnx",
        description: str = "",
        license: str = "",
        size_in_gb: float = 0.0,
        additional_files: list[str] | None = None,
    ) -> None:
        """Register a user-supplied model with ``CustomTextEmbedding``.

        Args:
            model: Name under which the model is registered.
            pooling: Pooling strategy, forwarded to ``CustomTextEmbedding.add_model``.
            normalization: Normalization flag, forwarded to ``CustomTextEmbedding.add_model``.
            sources: Where the model files come from.
            dim: Embedding dimensionality.
            model_file: Path of the ONNX file within the source.
            description: Free-form description.
            license: License identifier.
            size_in_gb: Model size, stored as ``size_in_GB``.
            additional_files: Extra files for the description; ``None`` becomes ``[]``.

        Raises:
            ValueError: If a model with the same name (case-insensitive) is already
                registered in any implementation of the registry.
        """
        requested_name = model.lower()
        for registered_model in cls._list_supported_models():
            if requested_name == registered_model.model.lower():
                raise ValueError(
                    f"Model {model} is already registered in TextEmbedding, if you still want to add this model, "
                    f"please use another model name"
                )

        CustomTextEmbedding.add_model(
            DenseModelDescription(
                model=model,
                sources=sources,
                dim=dim,
                model_file=model_file,
                description=description,
                license=license,
                size_in_GB=size_in_gb,
                additional_files=additional_files or [],
            ),
            pooling=pooling,
            normalization=normalization,
        )

    def __init__(
        self,
        model_name: str = "BAAI/bge-small-en-v1.5",
        cache_dir: str | None = None,
        threads: int | None = None,
        providers: Sequence[OnnxProvider] | None = None,
        cuda: bool | Device = Device.AUTO,
        device_ids: list[int] | None = None,
        lazy_load: bool = False,
        **kwargs: Any,
    ):
        """Select and instantiate the implementation that supports ``model_name``.

        Args:
            model_name: Name of the model to use; matched case-insensitively.
            cache_dir: Forwarded to the selected implementation.
            threads: Forwarded to the selected implementation.
            providers: Forwarded to the selected implementation.
            cuda: Forwarded to the selected implementation.
            device_ids: Forwarded to the selected implementation.
            lazy_load: Forwarded to the selected implementation.
            **kwargs: Forwarded both to the base initializer and to the selected implementation.

        Raises:
            ValueError: If no registered implementation supports ``model_name``.
        """
        super().__init__(model_name, cache_dir, threads, **kwargs)

        # Lower-case once; every comparison below is case-insensitive.
        normalized_name = model_name.lower()

        if normalized_name == "nomic-ai/nomic-embed-text-v1.5-Q".lower():
            warnings.warn(
                "The model 'nomic-ai/nomic-embed-text-v1.5-Q' has been updated on HuggingFace. Please review "
                "the latest documentation on HF and release notes to ensure compatibility with your workflow. ",
                UserWarning,
                stacklevel=2,
            )
        # Models whose pooling strategy changed; warn so callers can pin or re-register them.
        if normalized_name in {
            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2".lower(),
            "thenlper/gte-large".lower(),
            "intfloat/multilingual-e5-large".lower(),
            "sentence-transformers/paraphrase-multilingual-mpnet-base-v2".lower(),
        }:
            warnings.warn(
                f"The model {model_name} now uses mean pooling instead of CLS embedding. "
                f"In order to preserve the previous behaviour, consider either pinning fastembed version to 0.5.1 or "
                "using `add_custom_model` functionality.",
                UserWarning,
                stacklevel=2,
            )

        for embedding_type in self.EMBEDDINGS_REGISTRY:
            supported_models = embedding_type._list_supported_models()
            if any(normalized_name == model.model.lower() for model in supported_models):
                self.model = embedding_type(
                    model_name=model_name,
                    cache_dir=cache_dir,
                    threads=threads,
                    providers=providers,
                    cuda=cuda,
                    device_ids=device_ids,
                    lazy_load=lazy_load,
                    **kwargs,
                )
                return

        raise ValueError(
            f"Model {model_name} is not supported in TextEmbedding. "
            "Please check the supported models using `TextEmbedding.list_supported_models()`"
        )

    @property
    def embedding_size(self) -> int:
        """Get the embedding size of the current model"""
        # Resolved on first access and cached on the instance afterwards.
        if self._embedding_size is None:
            self._embedding_size = self.get_embedding_size(self.model_name)
        return self._embedding_size

    @classmethod
    def get_embedding_size(cls, model_name: str) -> int:
        """Get the embedding size of the passed model

        Args:
            model_name (str): The name of the model to get embedding size for.

        Returns:
            int: The size of the embedding.

        Raises:
            ValueError: If the model name is not found in the supported models.
        """
        descriptions = cls._list_supported_models()
        normalized_name = model_name.lower()
        embedding_size: int | None = None
        for description in descriptions:
            if description.model.lower() == normalized_name:
                embedding_size = description.dim
                break
        # Also reached when a matching description carries no dim.
        if embedding_size is None:
            model_names = [description.model for description in descriptions]
            raise ValueError(
                f"Embedding size for model {model_name} was None. "
                f"Available model names: {model_names}"
            )
        return embedding_size

    def embed(
        self,
        documents: str | Iterable[str],
        batch_size: int = 256,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> Iterable[NumpyArray]:
        """
        Encode a list of documents into list of embeddings.
        The work is delegated to the implementation selected for this model, so pooling
        and normalization depend on that implementation.

        Args:
            documents: Iterator of documents or single document to embed
            batch_size: Batch size for encoding -- higher values will use more memory, but be faster
            parallel:
                If > 1, data-parallel encoding will be used, recommended for offline encoding of large datasets.
                If 0, use all available cores.
                If None, don't use data-parallel processing, use default onnxruntime threading instead.

        Returns:
            List of embeddings, one per document
        """
        yield from self.model.embed(documents, batch_size, parallel, **kwargs)

    def query_embed(self, query: str | Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
        """
        Embeds queries

        Args:
            query (str | Iterable[str]): The query to embed, or an iterable e.g. list of queries.

        Returns:
            Iterable[NumpyArray]: The embeddings.
        """
        # This is model-specific, so that different models can have specialized implementations
        yield from self.model.query_embed(query, **kwargs)

    def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
        """
        Embeds a list of text passages into a list of embeddings.

        Args:
            texts (Iterable[str]): The list of texts to embed.
            **kwargs: Additional keyword argument to pass to the embed method.

        Yields:
            Iterable[NumpyArray]: The embeddings.
        """
        # This is model-specific, so that different models can have specialized implementations
        yield from self.model.passage_embed(texts, **kwargs)

    def token_count(
        self, texts: str | Iterable[str], batch_size: int = 1024, **kwargs: Any
    ) -> int:
        """Returns the number of tokens in the texts.

        Args:
            texts (str | Iterable[str]): The list of texts to embed.
            batch_size (int): Batch size for encoding

        Returns:
            int: Sum of number of tokens in the texts.
        """
        return self.model.token_count(texts, batch_size=batch_size, **kwargs)
|
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
from typing import Iterable, Any
|
|
2
|
+
|
|
3
|
+
from fastembed.common.model_description import DenseModelDescription
|
|
4
|
+
from fastembed.common.types import NumpyArray
|
|
5
|
+
from fastembed.common.model_management import ModelManagement
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TextEmbeddingBase(ModelManagement[DenseModelDescription]):
    """Common interface for dense text embedding implementations.

    Stores the construction parameters and provides default ``passage_embed`` and
    ``query_embed`` methods that route through ``embed``. The remaining methods
    must be supplied by subclasses.
    """

    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        threads: int | None = None,
        **kwargs: Any,
    ):
        """Record the model configuration on the instance.

        Args:
            model_name: Name of the model.
            cache_dir: Cache directory, stored as given.
            threads: Thread count, stored as given.
            **kwargs: Only ``local_files_only`` is read here (default ``False``).
        """
        self.model_name = model_name
        self.cache_dir = cache_dir
        self.threads = threads
        self._local_files_only = kwargs.pop("local_files_only", False)
        # Starts unset; this class never assigns it again.
        self._embedding_size: int | None = None

    def embed(
        self,
        documents: str | Iterable[str],
        batch_size: int = 256,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> Iterable[NumpyArray]:
        """Embed documents; not implemented in the base class.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
        """
        Embed text passages by delegating to ``embed``.

        Args:
            texts (Iterable[str]): The passages to embed.
            **kwargs: Extra keyword arguments handed to ``embed``.

        Yields:
            Iterable[NumpyArray]: The embeddings.
        """
        # Default behaviour: passages go through embed; subclasses may specialize.
        yield from self.embed(texts, **kwargs)

    def query_embed(self, query: str | Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
        """
        Embed one query or several queries by delegating to ``embed``.

        Args:
            query (str | Iterable[str]): A single query string or an iterable of queries.

        Returns:
            Iterable[NumpyArray]: The embeddings.
        """
        # A lone string is wrapped in a list so embed always receives a batch here.
        queries = [query] if isinstance(query, str) else query
        yield from self.embed(queries, **kwargs)

    @classmethod
    def get_embedding_size(cls, model_name: str) -> int:
        """Return the embedding size of ``model_name``; subclasses provide the lookup."""
        raise NotImplementedError("Subclasses must implement this method")

    @property
    def embedding_size(self) -> int:
        """Return the embedding size of this instance's model; subclasses provide it."""
        raise NotImplementedError("Subclasses must implement this method")

    def token_count(self, texts: str | Iterable[str], **kwargs: Any) -> int:
        """Return the number of tokens in ``texts``; subclasses provide the count."""
        raise NotImplementedError("Subclasses must implement this method")
|