fastembed_bio-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastembed/__init__.py +24 -0
- fastembed/bio/__init__.py +3 -0
- fastembed/bio/protein_embedding.py +456 -0
- fastembed/common/__init__.py +3 -0
- fastembed/common/model_description.py +52 -0
- fastembed/common/model_management.py +471 -0
- fastembed/common/onnx_model.py +188 -0
- fastembed/common/preprocessor_utils.py +84 -0
- fastembed/common/types.py +27 -0
- fastembed/common/utils.py +69 -0
- fastembed/embedding.py +24 -0
- fastembed/image/__init__.py +3 -0
- fastembed/image/image_embedding.py +135 -0
- fastembed/image/image_embedding_base.py +55 -0
- fastembed/image/onnx_embedding.py +217 -0
- fastembed/image/onnx_image_model.py +156 -0
- fastembed/image/transform/functional.py +221 -0
- fastembed/image/transform/operators.py +499 -0
- fastembed/late_interaction/__init__.py +5 -0
- fastembed/late_interaction/colbert.py +301 -0
- fastembed/late_interaction/jina_colbert.py +58 -0
- fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
- fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
- fastembed/late_interaction/token_embeddings.py +83 -0
- fastembed/late_interaction_multimodal/__init__.py +5 -0
- fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
- fastembed/late_interaction_multimodal/colpali.py +327 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
- fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
- fastembed/parallel_processor.py +253 -0
- fastembed/postprocess/__init__.py +3 -0
- fastembed/postprocess/muvera.py +362 -0
- fastembed/py.typed +1 -0
- fastembed/rerank/cross_encoder/__init__.py +3 -0
- fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
- fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
- fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
- fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
- fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
- fastembed/sparse/__init__.py +4 -0
- fastembed/sparse/bm25.py +359 -0
- fastembed/sparse/bm42.py +369 -0
- fastembed/sparse/minicoil.py +372 -0
- fastembed/sparse/sparse_embedding_base.py +90 -0
- fastembed/sparse/sparse_text_embedding.py +143 -0
- fastembed/sparse/splade_pp.py +196 -0
- fastembed/sparse/utils/minicoil_encoder.py +146 -0
- fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
- fastembed/sparse/utils/tokenizer.py +120 -0
- fastembed/sparse/utils/vocab_resolver.py +202 -0
- fastembed/text/__init__.py +3 -0
- fastembed/text/clip_embedding.py +56 -0
- fastembed/text/custom_text_embedding.py +97 -0
- fastembed/text/multitask_embedding.py +109 -0
- fastembed/text/onnx_embedding.py +353 -0
- fastembed/text/onnx_text_model.py +180 -0
- fastembed/text/pooled_embedding.py +136 -0
- fastembed/text/pooled_normalized_embedding.py +164 -0
- fastembed/text/text_embedding.py +228 -0
- fastembed/text/text_embedding_base.py +75 -0
- fastembed_bio-0.1.0.dist-info/METADATA +339 -0
- fastembed_bio-0.1.0.dist-info/RECORD +66 -0
- fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
- fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
- fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
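The `fastembed/bio/` package is this fork's addition on top of upstream fastembed. For orientation only, a hypothetical usage sketch: the `ProteinEmbedding` class name and its constructor are assumptions inferred from `fastembed/bio/protein_embedding.py` in the list above, not confirmed by this diff.

    # Hypothetical sketch; names are inferred from the file list, not confirmed by the diff.
    from fastembed.bio import ProteinEmbedding  # fastembed/bio/protein_embedding.py

    model = ProteinEmbedding(model_name="...")  # model id elided; see the package METADATA
    vectors = list(model.embed(["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"]))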
fastembed/sparse/utils/tokenizer.py
@@ -0,0 +1,120 @@
+# This code is a modified copy of the `NLTKWordTokenizer` class from the `NLTK` library.
+
+import re
+
+
+class SimpleTokenizer:
+    @staticmethod
+    def tokenize(text: str) -> list[str]:
+        text = re.sub(r"[^\w]", " ", text.lower())
+        text = re.sub(r"\s+", " ", text)
+
+        return text.strip().split()
+
+
+class WordTokenizer:
+    """The tokenizer is "destructive" such that the regexes applied will munge the
+    input string to a state beyond re-construction.
+    """
+
+    # Starting quotes.
+    STARTING_QUOTES = [
+        (re.compile("([«“‘„]|[`]+)", re.U), r" \1 "),
+        (re.compile(r"^\""), r"``"),
+        (re.compile(r"(``)"), r" \1 "),
+        (re.compile(r"([ \(\[{<])(\"|\'{2})"), r"\1 `` "),
+        (re.compile(r"(?i)(\')(?!re|ve|ll|m|t|s|d|n)(\w)\b", re.U), r"\1 \2"),
+    ]
+
+    # Ending quotes.
+    ENDING_QUOTES = [
+        (re.compile("([»”’])", re.U), r" \1 "),
+        (re.compile(r"''"), " '' "),
+        (re.compile(r'"'), " '' "),
+        (re.compile(r"([^' ])('[sS]|'[mM]|'[dD]|') "), r"\1 \2 "),
+        (re.compile(r"([^' ])('ll|'LL|'re|'RE|'ve|'VE|n't|N'T) "), r"\1 \2 "),
+    ]
+
+    # Punctuation.
+    PUNCTUATION = [
+        (re.compile(r'([^\.])(\.)([\]\)}>"\'' "»”’ " r"]*)\s*$", re.U), r"\1 \2 \3 "),
+        (re.compile(r"([:,])([^\d])"), r" \1 \2"),
+        (re.compile(r"([:,])$"), r" \1 "),
+        (
+            re.compile(r"\.{2,}", re.U),
+            r" \g<0> ",
+        ),
+        (re.compile(r"[;@#$%&]"), r" \g<0> "),
+        (
+            re.compile(r'([^\.])(\.)([\]\)}>"\']*)\s*$'),
+            r"\1 \2\3 ",
+        ),  # Handles the final period.
+        (re.compile(r"[?!]"), r" \g<0> "),
+        (re.compile(r"([^'])' "), r"\1 ' "),
+        (
+            re.compile(r"[*]", re.U),
+            r" \g<0> ",
+        ),
+    ]
+
+    # Pads parentheses
+    PARENS_BRACKETS = (re.compile(r"[\]\[\(\)\{\}\<\>]"), r" \g<0> ")
+    DOUBLE_DASHES = (re.compile(r"--"), r" -- ")
+
+    # List of contractions adapted from Robert MacIntyre's tokenizer.
+    CONTRACTIONS2 = [
+        re.compile(pattern)
+        for pattern in (
+            r"(?i)\b(can)(?#X)(not)\b",
+            r"(?i)\b(d)(?#X)('ye)\b",
+            r"(?i)\b(gim)(?#X)(me)\b",
+            r"(?i)\b(gon)(?#X)(na)\b",
+            r"(?i)\b(got)(?#X)(ta)\b",
+            r"(?i)\b(lem)(?#X)(me)\b",
+            r"(?i)\b(more)(?#X)('n)\b",
+            r"(?i)\b(wan)(?#X)(na)(?=\s)",
+        )
+    ]
+    CONTRACTIONS3 = [
+        re.compile(pattern) for pattern in (r"(?i) ('t)(?#X)(is)\b", r"(?i) ('t)(?#X)(was)\b")
+    ]
+
+    @classmethod
+    def tokenize(cls, text: str) -> list[str]:
+        """Return a tokenized copy of `text`.
+
+        >>> s = '''Good muffins cost $3.88 (roughly 3,36 euros)\nin New York.'''
+        >>> WordTokenizer().tokenize(s)
+        ['Good', 'muffins', 'cost', '$', '3.88', '(', 'roughly', '3,36', 'euros', ')', 'in', 'New', 'York', '.']
+
+        Args:
+            text: The text to be tokenized.
+
+        Returns:
+            A list of tokens.
+        """
+        for regexp, substitution in cls.STARTING_QUOTES:
+            text = regexp.sub(substitution, text)
+
+        for regexp, substitution in cls.PUNCTUATION:
+            text = regexp.sub(substitution, text)
+
+        # Handles parentheses.
+        regexp, substitution = cls.PARENS_BRACKETS
+        text = regexp.sub(substitution, text)
+
+        # Handles double dash.
+        regexp, substitution = cls.DOUBLE_DASHES
+        text = regexp.sub(substitution, text)
+
+        # add extra space to make things easier
+        text = " " + text + " "
+
+        for regexp, substitution in cls.ENDING_QUOTES:
+            text = regexp.sub(substitution, text)
+
+        for regexp in cls.CONTRACTIONS2:
+            text = regexp.sub(r" \1 \2 ", text)
+        for regexp in cls.CONTRACTIONS3:
+            text = regexp.sub(r" \1 \2 ", text)
+        return text.split()
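For orientation (not part of the package diff): `SimpleTokenizer` lowercases and discards punctuation, while `WordTokenizer` keeps case and splits punctuation into separate tokens. A minimal sketch, assuming the module path from the file list above:

    from fastembed.sparse.utils.tokenizer import SimpleTokenizer, WordTokenizer

    text = "Good muffins cost $3.88 in New York."
    print(SimpleTokenizer.tokenize(text))
    # ['good', 'muffins', 'cost', '3', '88', 'in', 'new', 'york']
    print(WordTokenizer.tokenize(text))
    # ['Good', 'muffins', 'cost', '$', '3.88', 'in', 'New', 'York', '.']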
fastembed/sparse/utils/vocab_resolver.py
@@ -0,0 +1,202 @@
+from collections import defaultdict
+from typing import Iterable
+
+from py_rust_stemmers import SnowballStemmer
+import numpy as np
+from tokenizers import Tokenizer
+from numpy.typing import NDArray
+
+from fastembed.common.types import NumpyArray
+
+
+class VocabTokenizerBase:
+    def tokenize(self, sentence: str) -> NumpyArray:
+        raise NotImplementedError()
+
+    def convert_ids_to_tokens(self, token_ids: NumpyArray) -> list[str]:
+        raise NotImplementedError()
+
+
+class VocabTokenizer(VocabTokenizerBase):
+    def __init__(self, tokenizer: Tokenizer):
+        self.tokenizer = tokenizer
+
+    def tokenize(self, sentence: str) -> NumpyArray:
+        return np.array(self.tokenizer.encode(sentence).ids)
+
+    def convert_ids_to_tokens(self, token_ids: NumpyArray) -> list[str]:
+        return [self.tokenizer.id_to_token(token_id) for token_id in token_ids]
+
+
+class VocabResolver:
+    def __init__(self, tokenizer: VocabTokenizerBase, stopwords: set[str], stemmer: SnowballStemmer):
+        # Word to id mapping
+        self.vocab: dict[str, int] = {}
+        # Id to word mapping
+        self.words: list[str] = []
+        # Lemma to word mapping
+        self.stem_mapping: dict[str, str] = {}
+        self.tokenizer: VocabTokenizerBase = tokenizer
+        self.stemmer = stemmer
+        self.stopwords: set[str] = stopwords
+
+    def tokenize(self, sentence: str) -> NumpyArray:
+        return self.tokenizer.tokenize(sentence)
+
+    def lookup_word(self, word_id: int) -> str:
+        if word_id == 0:
+            return "UNK"
+        return self.words[word_id - 1]
+
+    def convert_ids_to_tokens(self, token_ids: NumpyArray) -> list[str]:
+        return self.tokenizer.convert_ids_to_tokens(token_ids)
+
+    def vocab_size(self) -> int:
+        # We need +1 for UNK token
+        return len(self.vocab) + 1
+
+    def save_vocab(self, path: str) -> None:
+        with open(path, "w") as f:
+            for word in self.words:
+                f.write(word + "\n")
+
+    def save_json_vocab(self, path: str) -> None:
+        import json
+
+        with open(path, "w") as f:
+            json.dump({"vocab": self.words, "stem_mapping": self.stem_mapping}, f, indent=2)
+
+    def load_json_vocab(self, path: str) -> None:
+        import json
+
+        with open(path, "r") as f:
+            data = json.load(f)
+            self.words = data["vocab"]
+            self.vocab = {word: idx + 1 for idx, word in enumerate(self.words)}
+            self.stem_mapping = data["stem_mapping"]
+
+    def add_word(self, word: str) -> None:
+        if word not in self.vocab:
+            self.vocab[word] = len(self.vocab) + 1
+            self.words.append(word)
+            stem = self.stemmer.stem_word(word)
+            if stem not in self.stem_mapping:
+                self.stem_mapping[stem] = word
+            else:
+                existing_word = self.stem_mapping[stem]
+                if len(existing_word) > len(word):
+                    # Prefer shorter words for the same stem
+                    # Example: "swim" is preferred over "swimming"
+                    self.stem_mapping[stem] = word
+
+    def load_vocab(self, path: str) -> None:
+        with open(path, "r") as f:
+            for line in f:
+                self.add_word(line.strip())
+
+    @classmethod
+    def _reconstruct_bpe(
+        cls, bpe_tokens: Iterable[tuple[int, str]]
+    ) -> list[tuple[str, list[int]]]:
+        result: list[tuple[str, list[int]]] = []
+        acc: str = ""
+        acc_idx: list[int] = []
+
+        continuing_subword_prefix = "##"
+        continuing_subword_prefix_len = len(continuing_subword_prefix)
+
+        for idx, token in bpe_tokens:
+            if token.startswith(continuing_subword_prefix):
+                acc += token[continuing_subword_prefix_len:]
+                acc_idx.append(idx)
+            else:
+                if acc:
+                    result.append((acc, acc_idx))
+                    acc_idx = []
+                acc = token
+                acc_idx.append(idx)
+
+        if acc:
+            result.append((acc, acc_idx))
+        return result
+
+    def resolve_tokens(
+        self, token_ids: NDArray[np.int64]
+    ) -> tuple[NDArray[np.int64], dict[int, int], dict[str, int], dict[str, list[str]]]:
+        """
+        Mark known tokens (including composed tokens) with vocab ids.
+
+        Args:
+            token_ids: (seq_len) - list of ids of tokens
+                Example:
+                    [
+                        101, 3897, 19332, 12718, 23348,
+                        1010, 1996, 7151, 2296, 4845,
+                        2359, 2005, 4234, 1010, 4332,
+                        2871, 3191, 2062, 102
+                    ]
+
+        returns:
+            - token_ids with vocab ids
+                [
+                    0, 151, 151, 0, 0,
+                    912, 0, 0, 0, 332,
+                    332, 332, 332, 0, 7121, 191,
+                    0, 0, 332, 0
+                ]
+            - counts of each token
+                {
+                    151: 1,
+                    332: 3,
+                    7121: 1,
+                    191: 1,
+                    912: 1
+                }
+            - oov counts of each token
+                {
+                    "the": 1,
+                    "a": 1,
+                    "[CLS]": 1,
+                    "[SEP]": 1,
+                    ...
+                }
+            - forms of each token
+                {
+                    "hello": ["hello"],
+                    "world": ["worlds", "world", "worlding"],
+                }
+
+        """
+        tokens = self.convert_ids_to_tokens(token_ids)
+        tokens_mapping = self._reconstruct_bpe(enumerate(tokens))
+
+        counts: dict[int, int] = defaultdict(int)
+        oov_count: dict[str, int] = defaultdict(int)
+
+        forms: dict[str, list[str]] = defaultdict(list)
+
+        for token, mapped_token_ids in tokens_mapping:
+            vocab_id = 0
+            if token in self.stopwords:
+                vocab_id = 0
+            elif token in self.vocab:
+                vocab_id = self.vocab[token]
+                forms[token].append(token)
+            elif token in self.stem_mapping:
+                vocab_id = self.vocab[self.stem_mapping[token]]
+                forms[self.stem_mapping[token]].append(token)
+            else:
+                stem = self.stemmer.stem_word(token)
+                if stem in self.stem_mapping:
+                    vocab_id = self.vocab[self.stem_mapping[stem]]
+                    forms[self.stem_mapping[stem]].append(token)
+
+            for token_id in mapped_token_ids:
+                token_ids[token_id] = vocab_id
+
+            if vocab_id == 0:
+                oov_count[token] += 1
+            else:
+                counts[vocab_id] += 1
+        return token_ids, counts, oov_count, forms
+
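The subtlest part of `VocabResolver` is `_reconstruct_bpe`, which merges WordPiece-style pieces (marked by the `##` continuation prefix) back into surface words while tracking which token positions each word covers; `resolve_tokens` then matches each reconstructed word against the vocabulary, the stem mapping, or a freshly computed stem. A small illustrative sketch (not part of the diff):

    from fastembed.sparse.utils.vocab_resolver import VocabResolver

    # (position, token) pairs, as produced by enumerate(tokens) in resolve_tokens:
    pieces = [(0, "em"), (1, "##bed"), (2, "##ding"), (3, "models")]
    print(VocabResolver._reconstruct_bpe(pieces))
    # [('embedding', [0, 1, 2]), ('models', [3])]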
fastembed/text/clip_embedding.py
@@ -0,0 +1,56 @@
+from typing import Any, Iterable, Type
+
+from fastembed.common.types import NumpyArray
+from fastembed.common.onnx_model import OnnxOutputContext
+from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker
+from fastembed.common.model_description import DenseModelDescription, ModelSource
+
+supported_clip_models: list[DenseModelDescription] = [
+    DenseModelDescription(
+        model="Qdrant/clip-ViT-B-32-text",
+        dim=512,
+        description=(
+            "Text embeddings, Multimodal (text&image), English, 77 input tokens truncation, "
+            "Prefixes for queries/documents: not necessary, 2021 year"
+        ),
+        license="mit",
+        size_in_GB=0.25,
+        sources=ModelSource(hf="Qdrant/clip-ViT-B-32-text"),
+        model_file="model.onnx",
+    ),
+]
+
+
+class CLIPOnnxEmbedding(OnnxTextEmbedding):
+    @classmethod
+    def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
+        return CLIPEmbeddingWorker
+
+    @classmethod
+    def _list_supported_models(cls) -> list[DenseModelDescription]:
+        """Lists the supported models.
+
+        Returns:
+            list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
+        """
+        return supported_clip_models
+
+    def _post_process_onnx_output(
+        self, output: OnnxOutputContext, **kwargs: Any
+    ) -> Iterable[NumpyArray]:
+        return output.model_output
+
+
+class CLIPEmbeddingWorker(OnnxTextEmbeddingWorker):
+    def init_embedding(
+        self,
+        model_name: str,
+        cache_dir: str,
+        **kwargs: Any,
+    ) -> OnnxTextEmbedding:
+        return CLIPOnnxEmbedding(
+            model_name=model_name,
+            cache_dir=cache_dir,
+            threads=1,
+            **kwargs,
+        )
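`_post_process_onnx_output` returns the model output untouched, presumably because the exported CLIP text graph already emits pooled 512-dimensional sentence vectors. In upstream fastembed this class is reached through the `TextEmbedding` facade rather than instantiated directly; assuming fastembed-bio keeps that behavior, a minimal sketch:

    from fastembed import TextEmbedding

    model = TextEmbedding(model_name="Qdrant/clip-ViT-B-32-text")
    vectors = list(model.embed(["a photo of a dog"]))
    print(vectors[0].shape)  # (512,), matching dim=512 in the description above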
fastembed/text/custom_text_embedding.py
@@ -0,0 +1,97 @@
+from typing import Sequence, Any, Iterable
+from dataclasses import dataclass
+
+import numpy as np
+from numpy.typing import NDArray
+
+from fastembed.common import OnnxProvider
+from fastembed.common.model_description import (
+    PoolingType,
+    DenseModelDescription,
+)
+from fastembed.common.onnx_model import OnnxOutputContext
+from fastembed.common.types import NumpyArray, Device
+from fastembed.common.utils import normalize, mean_pooling
+from fastembed.text.onnx_embedding import OnnxTextEmbedding
+
+
+@dataclass(frozen=True)
+class PostprocessingConfig:
+    pooling: PoolingType
+    normalization: bool
+
+
+class CustomTextEmbedding(OnnxTextEmbedding):
+    SUPPORTED_MODELS: list[DenseModelDescription] = []
+    POSTPROCESSING_MAPPING: dict[str, PostprocessingConfig] = {}
+
+    def __init__(
+        self,
+        model_name: str,
+        cache_dir: str | None = None,
+        threads: int | None = None,
+        providers: Sequence[OnnxProvider] | None = None,
+        cuda: bool | Device = Device.AUTO,
+        device_ids: list[int] | None = None,
+        lazy_load: bool = False,
+        device_id: int | None = None,
+        specific_model_path: str | None = None,
+        **kwargs: Any,
+    ):
+        super().__init__(
+            model_name=model_name,
+            cache_dir=cache_dir,
+            threads=threads,
+            providers=providers,
+            cuda=cuda,
+            device_ids=device_ids,
+            lazy_load=lazy_load,
+            device_id=device_id,
+            specific_model_path=specific_model_path,
+            **kwargs,
+        )
+        self._pooling = self.POSTPROCESSING_MAPPING[model_name].pooling
+        self._normalization = self.POSTPROCESSING_MAPPING[model_name].normalization
+
+    @classmethod
+    def _list_supported_models(cls) -> list[DenseModelDescription]:
+        return cls.SUPPORTED_MODELS
+
+    def _post_process_onnx_output(
+        self, output: OnnxOutputContext, **kwargs: Any
+    ) -> Iterable[NumpyArray]:
+        return self._normalize(self._pool(output.model_output, output.attention_mask))
+
+    def _pool(
+        self, embeddings: NumpyArray, attention_mask: NDArray[np.int64] | None = None
+    ) -> NumpyArray:
+        if self._pooling == PoolingType.CLS:
+            return embeddings[:, 0]
+
+        if self._pooling == PoolingType.MEAN:
+            if attention_mask is None:
+                raise ValueError("attention_mask must be provided for mean pooling")
+            return mean_pooling(embeddings, attention_mask)
+
+        if self._pooling == PoolingType.DISABLED:
+            return embeddings
+
+        raise ValueError(
+            f"Unsupported pooling type {self._pooling}. "
+            f"Supported types are: {PoolingType.CLS}, {PoolingType.MEAN}, {PoolingType.DISABLED}."
+        )
+
+    def _normalize(self, embeddings: NumpyArray) -> NumpyArray:
+        return normalize(embeddings) if self._normalization else embeddings
+
+    @classmethod
+    def add_model(
+        cls,
+        model_description: DenseModelDescription,
+        pooling: PoolingType,
+        normalization: bool,
+    ) -> None:
+        cls.SUPPORTED_MODELS.append(model_description)
+        cls.POSTPROCESSING_MAPPING[model_description.model] = PostprocessingConfig(
+            pooling=pooling, normalization=normalization
+        )
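`add_model` is the hook that makes `CustomTextEmbedding` extensible at runtime: it appends the description to the class-level registry and records how that model's raw ONNX output should be pooled and normalized. A registration sketch; the model id and sizes below are placeholders, not part of the package:

    from fastembed.common.model_description import (
        DenseModelDescription,
        ModelSource,
        PoolingType,
    )
    from fastembed.text.custom_text_embedding import CustomTextEmbedding

    CustomTextEmbedding.add_model(
        DenseModelDescription(
            model="my-org/my-encoder",  # placeholder model id
            dim=768,
            description="Example custom encoder",
            license="apache-2.0",
            size_in_GB=0.44,
            sources=ModelSource(hf="my-org/my-encoder"),
            model_file="onnx/model.onnx",
        ),
        pooling=PoolingType.MEAN,
        normalization=True,
    )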
fastembed/text/multitask_embedding.py
@@ -0,0 +1,109 @@
+from enum import Enum
+from typing import Any, Type, Iterable
+
+import numpy as np
+
+from fastembed.common.onnx_model import OnnxOutputContext
+from fastembed.common.types import NumpyArray
+from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
+from fastembed.text.onnx_embedding import OnnxTextEmbeddingWorker
+from fastembed.common.model_description import DenseModelDescription, ModelSource
+
+supported_multitask_models: list[DenseModelDescription] = [
+    DenseModelDescription(
+        model="jinaai/jina-embeddings-v3",
+        dim=1024,
+        tasks={
+            "retrieval.query": 0,
+            "retrieval.passage": 1,
+            "separation": 2,
+            "classification": 3,
+            "text-matching": 4,
+        },
+        description=(
+            "Multi-task unimodal (text) embedding model, multi-lingual (~100), "
+            "1024 tokens truncation, and 8192 sequence length. Prefixes for queries/documents: not necessary, 2024 year."
+        ),
+        license="cc-by-nc-4.0",
+        size_in_GB=2.29,
+        sources=ModelSource(hf="jinaai/jina-embeddings-v3"),
+        model_file="onnx/model.onnx",
+        additional_files=["onnx/model.onnx_data"],
+    ),
+]
+
+
+class Task(int, Enum):
+    RETRIEVAL_QUERY = 0
+    RETRIEVAL_PASSAGE = 1
+    SEPARATION = 2
+    CLASSIFICATION = 3
+    TEXT_MATCHING = 4
+
+
+class JinaEmbeddingV3(PooledNormalizedEmbedding):
+    PASSAGE_TASK = Task.RETRIEVAL_PASSAGE
+    QUERY_TASK = Task.RETRIEVAL_QUERY
+
+    def __init__(self, *args: Any, task_id: int | None = None, **kwargs: Any):
+        super().__init__(*args, **kwargs)
+        self.default_task_id: Task | int = task_id if task_id is not None else self.PASSAGE_TASK
+
+    @classmethod
+    def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
+        return JinaEmbeddingV3Worker
+
+    @classmethod
+    def _list_supported_models(cls) -> list[DenseModelDescription]:
+        return supported_multitask_models
+
+    def _preprocess_onnx_input(
+        self,
+        onnx_input: dict[str, NumpyArray],
+        task_id: int | Task | None = None,
+        **kwargs: Any,
+    ) -> dict[str, NumpyArray]:
+        if task_id is None:
+            raise ValueError(f"task_id must be provided for JinaEmbeddingV3, got <{task_id}>")
+        onnx_input["task_id"] = np.array(task_id, dtype=np.int64)
+        return onnx_input
+
+    def embed(
+        self,
+        documents: str | Iterable[str],
+        batch_size: int = 256,
+        parallel: int | None = None,
+        task_id: int | None = None,
+        **kwargs: Any,
+    ) -> Iterable[NumpyArray]:
+        task_id = (
+            task_id if task_id is not None else self.default_task_id
+        )  # required for multiprocessing
+        yield from super().embed(documents, batch_size, parallel, task_id=task_id, **kwargs)
+
+    def query_embed(self, query: str | Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
+        yield from super().embed(query, task_id=self.QUERY_TASK, **kwargs)
+
+    def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
+        yield from super().embed(texts, task_id=self.PASSAGE_TASK, **kwargs)
+
+
+class JinaEmbeddingV3Worker(OnnxTextEmbeddingWorker):
+    def init_embedding(
+        self,
+        model_name: str,
+        cache_dir: str,
+        **kwargs: Any,
+    ) -> JinaEmbeddingV3:
+        return JinaEmbeddingV3(
+            model_name=model_name,
+            cache_dir=cache_dir,
+            threads=1,
+            **kwargs,
+        )
+
+    def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, OnnxOutputContext]]:
+        self.model: JinaEmbeddingV3  # mypy complains that `self.model` does not have `default_task_id`
+        for idx, batch in items:
+            onnx_output = self.model.onnx_embed(batch, task_id=self.model.default_task_id)
+            yield idx, onnx_output