fastembed-bio 0.1.0__py3-none-any.whl

This diff shows the content of a publicly available package version as it was released to a supported public registry. It is provided for informational purposes only and reflects the package files exactly as they appear in that registry.
Files changed (66)
  1. fastembed/__init__.py +24 -0
  2. fastembed/bio/__init__.py +3 -0
  3. fastembed/bio/protein_embedding.py +456 -0
  4. fastembed/common/__init__.py +3 -0
  5. fastembed/common/model_description.py +52 -0
  6. fastembed/common/model_management.py +471 -0
  7. fastembed/common/onnx_model.py +188 -0
  8. fastembed/common/preprocessor_utils.py +84 -0
  9. fastembed/common/types.py +27 -0
  10. fastembed/common/utils.py +69 -0
  11. fastembed/embedding.py +24 -0
  12. fastembed/image/__init__.py +3 -0
  13. fastembed/image/image_embedding.py +135 -0
  14. fastembed/image/image_embedding_base.py +55 -0
  15. fastembed/image/onnx_embedding.py +217 -0
  16. fastembed/image/onnx_image_model.py +156 -0
  17. fastembed/image/transform/functional.py +221 -0
  18. fastembed/image/transform/operators.py +499 -0
  19. fastembed/late_interaction/__init__.py +5 -0
  20. fastembed/late_interaction/colbert.py +301 -0
  21. fastembed/late_interaction/jina_colbert.py +58 -0
  22. fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
  23. fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
  24. fastembed/late_interaction/token_embeddings.py +83 -0
  25. fastembed/late_interaction_multimodal/__init__.py +5 -0
  26. fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
  27. fastembed/late_interaction_multimodal/colpali.py +327 -0
  28. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
  29. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
  30. fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
  31. fastembed/parallel_processor.py +253 -0
  32. fastembed/postprocess/__init__.py +3 -0
  33. fastembed/postprocess/muvera.py +362 -0
  34. fastembed/py.typed +1 -0
  35. fastembed/rerank/cross_encoder/__init__.py +3 -0
  36. fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
  37. fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
  38. fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
  39. fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
  40. fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
  41. fastembed/sparse/__init__.py +4 -0
  42. fastembed/sparse/bm25.py +359 -0
  43. fastembed/sparse/bm42.py +369 -0
  44. fastembed/sparse/minicoil.py +372 -0
  45. fastembed/sparse/sparse_embedding_base.py +90 -0
  46. fastembed/sparse/sparse_text_embedding.py +143 -0
  47. fastembed/sparse/splade_pp.py +196 -0
  48. fastembed/sparse/utils/minicoil_encoder.py +146 -0
  49. fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
  50. fastembed/sparse/utils/tokenizer.py +120 -0
  51. fastembed/sparse/utils/vocab_resolver.py +202 -0
  52. fastembed/text/__init__.py +3 -0
  53. fastembed/text/clip_embedding.py +56 -0
  54. fastembed/text/custom_text_embedding.py +97 -0
  55. fastembed/text/multitask_embedding.py +109 -0
  56. fastembed/text/onnx_embedding.py +353 -0
  57. fastembed/text/onnx_text_model.py +180 -0
  58. fastembed/text/pooled_embedding.py +136 -0
  59. fastembed/text/pooled_normalized_embedding.py +164 -0
  60. fastembed/text/text_embedding.py +228 -0
  61. fastembed/text/text_embedding_base.py +75 -0
  62. fastembed_bio-0.1.0.dist-info/METADATA +339 -0
  63. fastembed_bio-0.1.0.dist-info/RECORD +66 -0
  64. fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
  65. fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
  66. fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
fastembed/sparse/utils/tokenizer.py
@@ -0,0 +1,120 @@
+ # This code is a modified copy of the `NLTKWordTokenizer` class from `NLTK` library.
+
+ import re
+
+
+ class SimpleTokenizer:
+     @staticmethod
+     def tokenize(text: str) -> list[str]:
+         text = re.sub(r"[^\w]", " ", text.lower())
+         text = re.sub(r"\s+", " ", text)
+
+         return text.strip().split()
+
+
+ class WordTokenizer:
+     """The tokenizer is "destructive" such that the regexes applied will munge the
+     input string to a state beyond re-construction.
+     """
+
+     # Starting quotes.
+     STARTING_QUOTES = [
+         (re.compile("([«“‘„]|[`]+)", re.U), r" \1 "),
+         (re.compile(r"^\""), r"``"),
+         (re.compile(r"(``)"), r" \1 "),
+         (re.compile(r"([ \(\[{<])(\"|\'{2})"), r"\1 `` "),
+         (re.compile(r"(?i)(\')(?!re|ve|ll|m|t|s|d|n)(\w)\b", re.U), r"\1 \2"),
+     ]
+
+     # Ending quotes.
+     ENDING_QUOTES = [
+         (re.compile("([»”’])", re.U), r" \1 "),
+         (re.compile(r"''"), " '' "),
+         (re.compile(r'"'), " '' "),
+         (re.compile(r"([^' ])('[sS]|'[mM]|'[dD]|') "), r"\1 \2 "),
+         (re.compile(r"([^' ])('ll|'LL|'re|'RE|'ve|'VE|n't|N'T) "), r"\1 \2 "),
+     ]
+
+     # Punctuation.
+     PUNCTUATION = [
+         (re.compile(r'([^\.])(\.)([\]\)}>"\'' "»”’ " r"]*)\s*$", re.U), r"\1 \2 \3 "),
+         (re.compile(r"([:,])([^\d])"), r" \1 \2"),
+         (re.compile(r"([:,])$"), r" \1 "),
+         (
+             re.compile(r"\.{2,}", re.U),
+             r" \g<0> ",
+         ),
+         (re.compile(r"[;@#$%&]"), r" \g<0> "),
+         (
+             re.compile(r'([^\.])(\.)([\]\)}>"\']*)\s*$'),
+             r"\1 \2\3 ",
+         ),  # Handles the final period.
+         (re.compile(r"[?!]"), r" \g<0> "),
+         (re.compile(r"([^'])' "), r"\1 ' "),
+         (
+             re.compile(r"[*]", re.U),
+             r" \g<0> ",
+         ),
+     ]
+
+     # Pads parentheses
+     PARENS_BRACKETS = (re.compile(r"[\]\[\(\)\{\}\<\>]"), r" \g<0> ")
+     DOUBLE_DASHES = (re.compile(r"--"), r" -- ")
+
+     # List of contractions adapted from Robert MacIntyre's tokenizer.
+     CONTRACTIONS2 = [
+         re.compile(pattern)
+         for pattern in (
+             r"(?i)\b(can)(?#X)(not)\b",
+             r"(?i)\b(d)(?#X)('ye)\b",
+             r"(?i)\b(gim)(?#X)(me)\b",
+             r"(?i)\b(gon)(?#X)(na)\b",
+             r"(?i)\b(got)(?#X)(ta)\b",
+             r"(?i)\b(lem)(?#X)(me)\b",
+             r"(?i)\b(more)(?#X)('n)\b",
+             r"(?i)\b(wan)(?#X)(na)(?=\s)",
+         )
+     ]
+     CONTRACTIONS3 = [
+         re.compile(pattern) for pattern in (r"(?i) ('t)(?#X)(is)\b", r"(?i) ('t)(?#X)(was)\b")
+     ]
+
+     @classmethod
+     def tokenize(cls, text: str) -> list[str]:
+         """Return a tokenized copy of `text`.
+
+         >>> s = '''Good muffins cost $3.88 (roughly 3,36 euros)\nin New York.'''
+         >>> WordTokenizer().tokenize(s)
+         ['Good', 'muffins', 'cost', '$', '3.88', '(', 'roughly', '3,36', 'euros', ')', 'in', 'New', 'York', '.']
+
+         Args:
+             text: The text to be tokenized.
+
+         Returns:
+             A list of tokens.
+         """
+         for regexp, substitution in cls.STARTING_QUOTES:
+             text = regexp.sub(substitution, text)
+
+         for regexp, substitution in cls.PUNCTUATION:
+             text = regexp.sub(substitution, text)
+
+         # Handles parentheses.
+         regexp, substitution = cls.PARENS_BRACKETS
+         text = regexp.sub(substitution, text)
+
+         # Handles double dash.
+         regexp, substitution = cls.DOUBLE_DASHES
+         text = regexp.sub(substitution, text)
+
+         # add extra space to make things easier
+         text = " " + text + " "
+
+         for regexp, substitution in cls.ENDING_QUOTES:
+             text = regexp.sub(substitution, text)
+
+         for regexp in cls.CONTRACTIONS2:
+             text = regexp.sub(r" \1 \2 ", text)
+         for regexp in cls.CONTRACTIONS3:
+             text = regexp.sub(r" \1 \2 ", text)
+         return text.split()
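
For orientation, a minimal usage sketch of the two tokenizers above, reusing the sentence from the WordTokenizer doctest (it assumes only the module shown here, imported from its path in the file list):

from fastembed.sparse.utils.tokenizer import SimpleTokenizer, WordTokenizer

s = "Good muffins cost $3.88 (roughly 3,36 euros)\nin New York."

# Lowercases and strips all punctuation before splitting on whitespace.
print(SimpleTokenizer.tokenize(s))
# ['good', 'muffins', 'cost', '3', '88', 'roughly', '3', '36', 'euros', 'in', 'new', 'york']

# Keeps case and splits punctuation into separate tokens (NLTK-style).
print(WordTokenizer.tokenize(s))
# ['Good', 'muffins', 'cost', '$', '3.88', '(', 'roughly', '3,36', 'euros', ')', 'in', 'New', 'York', '.']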
fastembed/sparse/utils/vocab_resolver.py
@@ -0,0 +1,202 @@
+ from collections import defaultdict
+ from typing import Iterable
+
+ from py_rust_stemmers import SnowballStemmer
+ import numpy as np
+ from tokenizers import Tokenizer
+ from numpy.typing import NDArray
+
+ from fastembed.common.types import NumpyArray
+
+
+ class VocabTokenizerBase:
+     def tokenize(self, sentence: str) -> NumpyArray:
+         raise NotImplementedError()
+
+     def convert_ids_to_tokens(self, token_ids: NumpyArray) -> list[str]:
+         raise NotImplementedError()
+
+
+ class VocabTokenizer(VocabTokenizerBase):
+     def __init__(self, tokenizer: Tokenizer):
+         self.tokenizer = tokenizer
+
+     def tokenize(self, sentence: str) -> NumpyArray:
+         return np.array(self.tokenizer.encode(sentence).ids)
+
+     def convert_ids_to_tokens(self, token_ids: NumpyArray) -> list[str]:
+         return [self.tokenizer.id_to_token(token_id) for token_id in token_ids]
+
+
+ class VocabResolver:
+     def __init__(self, tokenizer: VocabTokenizerBase, stopwords: set[str], stemmer: SnowballStemmer):
+         # Word to id mapping
+         self.vocab: dict[str, int] = {}
+         # Id to word mapping
+         self.words: list[str] = []
+         # Lemma to word mapping
+         self.stem_mapping: dict[str, str] = {}
+         self.tokenizer: VocabTokenizerBase = tokenizer
+         self.stemmer = stemmer
+         self.stopwords: set[str] = stopwords
+
+     def tokenize(self, sentence: str) -> NumpyArray:
+         return self.tokenizer.tokenize(sentence)
+
+     def lookup_word(self, word_id: int) -> str:
+         if word_id == 0:
+             return "UNK"
+         return self.words[word_id - 1]
+
+     def convert_ids_to_tokens(self, token_ids: NumpyArray) -> list[str]:
+         return self.tokenizer.convert_ids_to_tokens(token_ids)
+
+     def vocab_size(self) -> int:
+         # We need +1 for UNK token
+         return len(self.vocab) + 1
+
+     def save_vocab(self, path: str) -> None:
+         with open(path, "w") as f:
+             for word in self.words:
+                 f.write(word + "\n")
+
+     def save_json_vocab(self, path: str) -> None:
+         import json
+
+         with open(path, "w") as f:
+             json.dump({"vocab": self.words, "stem_mapping": self.stem_mapping}, f, indent=2)
+
+     def load_json_vocab(self, path: str) -> None:
+         import json
+
+         with open(path, "r") as f:
+             data = json.load(f)
+             self.words = data["vocab"]
+             self.vocab = {word: idx + 1 for idx, word in enumerate(self.words)}
+             self.stem_mapping = data["stem_mapping"]
+
+     def add_word(self, word: str) -> None:
+         if word not in self.vocab:
+             self.vocab[word] = len(self.vocab) + 1
+             self.words.append(word)
+             stem = self.stemmer.stem_word(word)
+             if stem not in self.stem_mapping:
+                 self.stem_mapping[stem] = word
+             else:
+                 existing_word = self.stem_mapping[stem]
+                 if len(existing_word) > len(word):
+                     # Prefer shorter words for the same stem
+                     # Example: "swim" is preferred over "swimming"
+                     self.stem_mapping[stem] = word
+
+     def load_vocab(self, path: str) -> None:
+         with open(path, "r") as f:
+             for line in f:
+                 self.add_word(line.strip())
+
+     @classmethod
+     def _reconstruct_bpe(
+         cls, bpe_tokens: Iterable[tuple[int, str]]
+     ) -> list[tuple[str, list[int]]]:
+         result: list[tuple[str, list[int]]] = []
+         acc: str = ""
+         acc_idx: list[int] = []
+
+         continuing_subword_prefix = "##"
+         continuing_subword_prefix_len = len(continuing_subword_prefix)
+
+         for idx, token in bpe_tokens:
+             if token.startswith(continuing_subword_prefix):
+                 acc += token[continuing_subword_prefix_len:]
+                 acc_idx.append(idx)
+             else:
+                 if acc:
+                     result.append((acc, acc_idx))
+                     acc_idx = []
+                 acc = token
+                 acc_idx.append(idx)
+
+         if acc:
+             result.append((acc, acc_idx))
+         return result
+
+     def resolve_tokens(
+         self, token_ids: NDArray[np.int64]
+     ) -> tuple[NDArray[np.int64], dict[int, int], dict[str, int], dict[str, list[str]]]:
+         """
+         Mark known tokens (including composed tokens) with vocab ids.
+
+         Args:
+             token_ids: (seq_len) - list of ids of tokens
+                 Example:
+                     [
+                         101, 3897, 19332, 12718, 23348,
+                         1010, 1996, 7151, 2296, 4845,
+                         2359, 2005, 4234, 1010, 4332,
+                         2871, 3191, 2062, 102
+                     ]
+
+         returns:
+             - token_ids with vocab ids
+                 [
+                     0, 151, 151, 0, 0,
+                     912, 0, 0, 0, 332,
+                     332, 332, 0, 7121, 191,
+                     0, 0, 332, 0
+                 ]
+             - counts of each token
+                 {
+                     151: 1,
+                     332: 3,
+                     7121: 1,
+                     191: 1,
+                     912: 1
+                 }
+             - oov counts of each token
+                 {
+                     "the": 1,
+                     "a": 1,
+                     "[CLS]": 1,
+                     "[SEP]": 1,
+                     ...
+                 }
+             - forms of each token
+                 {
+                     "hello": ["hello"],
+                     "world": ["worlds", "world", "worlding"],
+                 }
+         """
+         tokens = self.convert_ids_to_tokens(token_ids)
+         tokens_mapping = self._reconstruct_bpe(enumerate(tokens))
+
+         counts: dict[int, int] = defaultdict(int)
+         oov_count: dict[str, int] = defaultdict(int)
+
+         forms: dict[str, list[str]] = defaultdict(list)
+
+         for token, mapped_token_ids in tokens_mapping:
+             vocab_id = 0
+             if token in self.stopwords:
+                 vocab_id = 0
+             elif token in self.vocab:
+                 vocab_id = self.vocab[token]
+                 forms[token].append(token)
+             elif token in self.stem_mapping:
+                 vocab_id = self.vocab[self.stem_mapping[token]]
+                 forms[self.stem_mapping[token]].append(token)
+             else:
+                 stem = self.stemmer.stem_word(token)
+                 if stem in self.stem_mapping:
+                     vocab_id = self.vocab[self.stem_mapping[stem]]
+                     forms[self.stem_mapping[stem]].append(token)
+
+             for token_id in mapped_token_ids:
+                 token_ids[token_id] = vocab_id
+
+             if vocab_id == 0:
+                 oov_count[token] += 1
+             else:
+                 counts[vocab_id] += 1
+         return token_ids, counts, oov_count, forms
+
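
A rough usage sketch of the resolver above (illustrative only: the tokenizer checkpoint, stopword set, and vocabulary are placeholders; SnowballStemmer("english") assumes py_rust_stemmers takes a language name, while stem_word() is the method the class itself uses):

from py_rust_stemmers import SnowballStemmer
from tokenizers import Tokenizer

from fastembed.sparse.utils.vocab_resolver import VocabResolver, VocabTokenizer

resolver = VocabResolver(
    tokenizer=VocabTokenizer(Tokenizer.from_pretrained("bert-base-uncased")),  # placeholder checkpoint
    stopwords={"the", "a", "is"},
    stemmer=SnowballStemmer("english"),
)
for word in ["swim", "world", "hello"]:
    resolver.add_word(word)  # also records stem -> word in stem_mapping

token_ids = resolver.tokenize("The swimmers swam around the worlds")
token_ids, counts, oov_count, forms = resolver.resolve_tokens(token_ids)
# token_ids now holds vocab ids (0 = unknown or stopword), counts maps vocab id -> frequency,
# oov_count collects unmatched surface forms, and forms lists the variants matched per vocab word.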
fastembed/text/__init__.py
@@ -0,0 +1,3 @@
+ from fastembed.text.text_embedding import TextEmbedding
+
+ __all__ = ["TextEmbedding"]
fastembed/text/clip_embedding.py
@@ -0,0 +1,56 @@
+ from typing import Any, Iterable, Type
+
+ from fastembed.common.types import NumpyArray
+ from fastembed.common.onnx_model import OnnxOutputContext
+ from fastembed.text.onnx_embedding import OnnxTextEmbedding, OnnxTextEmbeddingWorker
+ from fastembed.common.model_description import DenseModelDescription, ModelSource
+
+ supported_clip_models: list[DenseModelDescription] = [
+     DenseModelDescription(
+         model="Qdrant/clip-ViT-B-32-text",
+         dim=512,
+         description=(
+             "Text embeddings, Multimodal (text&image), English, 77 input tokens truncation, "
+             "Prefixes for queries/documents: not necessary, 2021 year"
+         ),
+         license="mit",
+         size_in_GB=0.25,
+         sources=ModelSource(hf="Qdrant/clip-ViT-B-32-text"),
+         model_file="model.onnx",
+     ),
+ ]
+
+
+ class CLIPOnnxEmbedding(OnnxTextEmbedding):
+     @classmethod
+     def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
+         return CLIPEmbeddingWorker
+
+     @classmethod
+     def _list_supported_models(cls) -> list[DenseModelDescription]:
+         """Lists the supported models.
+
+         Returns:
+             list[DenseModelDescription]: A list of DenseModelDescription objects containing the model information.
+         """
+         return supported_clip_models
+
+     def _post_process_onnx_output(
+         self, output: OnnxOutputContext, **kwargs: Any
+     ) -> Iterable[NumpyArray]:
+         return output.model_output
+
+
+ class CLIPEmbeddingWorker(OnnxTextEmbeddingWorker):
+     def init_embedding(
+         self,
+         model_name: str,
+         cache_dir: str,
+         **kwargs: Any,
+     ) -> OnnxTextEmbedding:
+         return CLIPOnnxEmbedding(
+             model_name=model_name,
+             cache_dir=cache_dir,
+             threads=1,
+             **kwargs,
+         )
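
For context, a usage sketch of the CLIP text model registered above (an assumption: as in upstream fastembed, the top-level TextEmbedding class dispatches this model name to CLIPOnnxEmbedding; the example texts are placeholders):

from fastembed import TextEmbedding

# Assumes TextEmbedding routes "Qdrant/clip-ViT-B-32-text" to CLIPOnnxEmbedding.
model = TextEmbedding(model_name="Qdrant/clip-ViT-B-32-text")

# Yields one 512-dimensional vector per document (dim=512 per the description above).
vectors = list(model.embed(["a photo of a cat", "a photo of a dog"]))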
fastembed/text/custom_text_embedding.py
@@ -0,0 +1,97 @@
+ from typing import Sequence, Any, Iterable
+ from dataclasses import dataclass
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from fastembed.common import OnnxProvider
+ from fastembed.common.model_description import (
+     PoolingType,
+     DenseModelDescription,
+ )
+ from fastembed.common.onnx_model import OnnxOutputContext
+ from fastembed.common.types import NumpyArray, Device
+ from fastembed.common.utils import normalize, mean_pooling
+ from fastembed.text.onnx_embedding import OnnxTextEmbedding
+
+
+ @dataclass(frozen=True)
+ class PostprocessingConfig:
+     pooling: PoolingType
+     normalization: bool
+
+
+ class CustomTextEmbedding(OnnxTextEmbedding):
+     SUPPORTED_MODELS: list[DenseModelDescription] = []
+     POSTPROCESSING_MAPPING: dict[str, PostprocessingConfig] = {}
+
+     def __init__(
+         self,
+         model_name: str,
+         cache_dir: str | None = None,
+         threads: int | None = None,
+         providers: Sequence[OnnxProvider] | None = None,
+         cuda: bool | Device = Device.AUTO,
+         device_ids: list[int] | None = None,
+         lazy_load: bool = False,
+         device_id: int | None = None,
+         specific_model_path: str | None = None,
+         **kwargs: Any,
+     ):
+         super().__init__(
+             model_name=model_name,
+             cache_dir=cache_dir,
+             threads=threads,
+             providers=providers,
+             cuda=cuda,
+             device_ids=device_ids,
+             lazy_load=lazy_load,
+             device_id=device_id,
+             specific_model_path=specific_model_path,
+             **kwargs,
+         )
+         self._pooling = self.POSTPROCESSING_MAPPING[model_name].pooling
+         self._normalization = self.POSTPROCESSING_MAPPING[model_name].normalization
+
+     @classmethod
+     def _list_supported_models(cls) -> list[DenseModelDescription]:
+         return cls.SUPPORTED_MODELS
+
+     def _post_process_onnx_output(
+         self, output: OnnxOutputContext, **kwargs: Any
+     ) -> Iterable[NumpyArray]:
+         return self._normalize(self._pool(output.model_output, output.attention_mask))
+
+     def _pool(
+         self, embeddings: NumpyArray, attention_mask: NDArray[np.int64] | None = None
+     ) -> NumpyArray:
+         if self._pooling == PoolingType.CLS:
+             return embeddings[:, 0]
+
+         if self._pooling == PoolingType.MEAN:
+             if attention_mask is None:
+                 raise ValueError("attention_mask must be provided for mean pooling")
+             return mean_pooling(embeddings, attention_mask)
+
+         if self._pooling == PoolingType.DISABLED:
+             return embeddings
+
+         raise ValueError(
+             f"Unsupported pooling type {self._pooling}. "
+             f"Supported types are: {PoolingType.CLS}, {PoolingType.MEAN}, {PoolingType.DISABLED}."
+         )
+
+     def _normalize(self, embeddings: NumpyArray) -> NumpyArray:
+         return normalize(embeddings) if self._normalization else embeddings
+
+     @classmethod
+     def add_model(
+         cls,
+         model_description: DenseModelDescription,
+         pooling: PoolingType,
+         normalization: bool,
+     ) -> None:
+         cls.SUPPORTED_MODELS.append(model_description)
+         cls.POSTPROCESSING_MAPPING[model_description.model] = PostprocessingConfig(
+             pooling=pooling, normalization=normalization
+         )
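
A sketch of registering a custom model through add_model above (illustrative: the model name, dimension, and size are placeholders, and only DenseModelDescription fields already used elsewhere in this diff are assumed):

from fastembed.common.model_description import DenseModelDescription, ModelSource, PoolingType
from fastembed.text.custom_text_embedding import CustomTextEmbedding

CustomTextEmbedding.add_model(
    DenseModelDescription(
        model="my-org/my-encoder",  # placeholder model name
        dim=384,                    # placeholder dimension
        description="Hypothetical custom ONNX encoder",
        license="apache-2.0",
        size_in_GB=0.1,
        sources=ModelSource(hf="my-org/my-encoder"),
        model_file="onnx/model.onnx",
    ),
    pooling=PoolingType.MEAN,  # mean pooling requires the attention mask (see _pool above)
    normalization=True,        # applies normalize() after pooling
)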
fastembed/text/multitask_embedding.py
@@ -0,0 +1,109 @@
+ from enum import Enum
+ from typing import Any, Type, Iterable
+
+ import numpy as np
+
+ from fastembed.common.onnx_model import OnnxOutputContext
+ from fastembed.common.types import NumpyArray
+ from fastembed.text.pooled_normalized_embedding import PooledNormalizedEmbedding
+ from fastembed.text.onnx_embedding import OnnxTextEmbeddingWorker
+ from fastembed.common.model_description import DenseModelDescription, ModelSource
+
+ supported_multitask_models: list[DenseModelDescription] = [
+     DenseModelDescription(
+         model="jinaai/jina-embeddings-v3",
+         dim=1024,
+         tasks={
+             "retrieval.query": 0,
+             "retrieval.passage": 1,
+             "separation": 2,
+             "classification": 3,
+             "text-matching": 4,
+         },
+         description=(
+             "Multi-task unimodal (text) embedding model, multi-lingual (~100), "
+             "1024 tokens truncation, and 8192 sequence length. Prefixes for queries/documents: not necessary, 2024 year."
+         ),
+         license="cc-by-nc-4.0",
+         size_in_GB=2.29,
+         sources=ModelSource(hf="jinaai/jina-embeddings-v3"),
+         model_file="onnx/model.onnx",
+         additional_files=["onnx/model.onnx_data"],
+     ),
+ ]
+
+
+ class Task(int, Enum):
+     RETRIEVAL_QUERY = 0
+     RETRIEVAL_PASSAGE = 1
+     SEPARATION = 2
+     CLASSIFICATION = 3
+     TEXT_MATCHING = 4
+
+
+ class JinaEmbeddingV3(PooledNormalizedEmbedding):
+     PASSAGE_TASK = Task.RETRIEVAL_PASSAGE
+     QUERY_TASK = Task.RETRIEVAL_QUERY
+
+     def __init__(self, *args: Any, task_id: int | None = None, **kwargs: Any):
+         super().__init__(*args, **kwargs)
+         self.default_task_id: Task | int = task_id if task_id is not None else self.PASSAGE_TASK
+
+     @classmethod
+     def _get_worker_class(cls) -> Type[OnnxTextEmbeddingWorker]:
+         return JinaEmbeddingV3Worker
+
+     @classmethod
+     def _list_supported_models(cls) -> list[DenseModelDescription]:
+         return supported_multitask_models
+
+     def _preprocess_onnx_input(
+         self,
+         onnx_input: dict[str, NumpyArray],
+         task_id: int | Task | None = None,
+         **kwargs: Any,
+     ) -> dict[str, NumpyArray]:
+         if task_id is None:
+             raise ValueError(f"task_id must be provided for JinaEmbeddingV3, got <{task_id}>")
+         onnx_input["task_id"] = np.array(task_id, dtype=np.int64)
+         return onnx_input
+
+     def embed(
+         self,
+         documents: str | Iterable[str],
+         batch_size: int = 256,
+         parallel: int | None = None,
+         task_id: int | None = None,
+         **kwargs: Any,
+     ) -> Iterable[NumpyArray]:
+         task_id = (
+             task_id if task_id is not None else self.default_task_id
+         )  # required for multiprocessing
+         yield from super().embed(documents, batch_size, parallel, task_id=task_id, **kwargs)
+
+     def query_embed(self, query: str | Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
+         yield from super().embed(query, task_id=self.QUERY_TASK, **kwargs)
+
+     def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
+         yield from super().embed(texts, task_id=self.PASSAGE_TASK, **kwargs)
+
+
+ class JinaEmbeddingV3Worker(OnnxTextEmbeddingWorker):
+     def init_embedding(
+         self,
+         model_name: str,
+         cache_dir: str,
+         **kwargs: Any,
+     ) -> JinaEmbeddingV3:
+         return JinaEmbeddingV3(
+             model_name=model_name,
+             cache_dir=cache_dir,
+             threads=1,
+             **kwargs,
+         )
+
+     def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, OnnxOutputContext]]:
+         self.model: JinaEmbeddingV3  # mypy complaints `self.model` does not have `default_task_id`
+         for idx, batch in items:
+             onnx_output = self.model.onnx_embed(batch, task_id=self.model.default_task_id)
+             yield idx, onnx_output
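
A sketch of the task-selection behavior above (illustrative; it instantiates JinaEmbeddingV3 directly with the constructor shown in this file, and the input texts are placeholders):

from fastembed.text.multitask_embedding import JinaEmbeddingV3, Task

# task_id defaults to Task.RETRIEVAL_PASSAGE when not given.
model = JinaEmbeddingV3(model_name="jinaai/jina-embeddings-v3", task_id=Task.TEXT_MATCHING)

# embed() forwards the default task id; query_embed()/passage_embed() override it
# with the retrieval.query and retrieval.passage heads respectively.
vectors = list(model.embed(["FastEmbed is a lightweight embedding library."]))
queries = list(model.query_embed("what is fastembed?"))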