fastembed_bio-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastembed/__init__.py +24 -0
- fastembed/bio/__init__.py +3 -0
- fastembed/bio/protein_embedding.py +456 -0
- fastembed/common/__init__.py +3 -0
- fastembed/common/model_description.py +52 -0
- fastembed/common/model_management.py +471 -0
- fastembed/common/onnx_model.py +188 -0
- fastembed/common/preprocessor_utils.py +84 -0
- fastembed/common/types.py +27 -0
- fastembed/common/utils.py +69 -0
- fastembed/embedding.py +24 -0
- fastembed/image/__init__.py +3 -0
- fastembed/image/image_embedding.py +135 -0
- fastembed/image/image_embedding_base.py +55 -0
- fastembed/image/onnx_embedding.py +217 -0
- fastembed/image/onnx_image_model.py +156 -0
- fastembed/image/transform/functional.py +221 -0
- fastembed/image/transform/operators.py +499 -0
- fastembed/late_interaction/__init__.py +5 -0
- fastembed/late_interaction/colbert.py +301 -0
- fastembed/late_interaction/jina_colbert.py +58 -0
- fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
- fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
- fastembed/late_interaction/token_embeddings.py +83 -0
- fastembed/late_interaction_multimodal/__init__.py +5 -0
- fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
- fastembed/late_interaction_multimodal/colpali.py +327 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
- fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
- fastembed/parallel_processor.py +253 -0
- fastembed/postprocess/__init__.py +3 -0
- fastembed/postprocess/muvera.py +362 -0
- fastembed/py.typed +1 -0
- fastembed/rerank/cross_encoder/__init__.py +3 -0
- fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
- fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
- fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
- fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
- fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
- fastembed/sparse/__init__.py +4 -0
- fastembed/sparse/bm25.py +359 -0
- fastembed/sparse/bm42.py +369 -0
- fastembed/sparse/minicoil.py +372 -0
- fastembed/sparse/sparse_embedding_base.py +90 -0
- fastembed/sparse/sparse_text_embedding.py +143 -0
- fastembed/sparse/splade_pp.py +196 -0
- fastembed/sparse/utils/minicoil_encoder.py +146 -0
- fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
- fastembed/sparse/utils/tokenizer.py +120 -0
- fastembed/sparse/utils/vocab_resolver.py +202 -0
- fastembed/text/__init__.py +3 -0
- fastembed/text/clip_embedding.py +56 -0
- fastembed/text/custom_text_embedding.py +97 -0
- fastembed/text/multitask_embedding.py +109 -0
- fastembed/text/onnx_embedding.py +353 -0
- fastembed/text/onnx_text_model.py +180 -0
- fastembed/text/pooled_embedding.py +136 -0
- fastembed/text/pooled_normalized_embedding.py +164 -0
- fastembed/text/text_embedding.py +228 -0
- fastembed/text/text_embedding_base.py +75 -0
- fastembed_bio-0.1.0.dist-info/METADATA +339 -0
- fastembed_bio-0.1.0.dist-info/RECORD +66 -0
- fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
- fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
- fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
fastembed/common/model_management.py
@@ -0,0 +1,471 @@
+import os
+import time
+import json
+import shutil
+import tarfile
+from copy import deepcopy
+from pathlib import Path
+from typing import Any, TypeVar, Generic
+
+import requests
+from huggingface_hub import snapshot_download, model_info, list_repo_tree
+from huggingface_hub.hf_api import RepoFile
+from huggingface_hub.utils import (
+    RepositoryNotFoundError,
+    disable_progress_bars,
+    enable_progress_bars,
+)
+from loguru import logger
+from tqdm import tqdm
+from fastembed.common.model_description import BaseModelDescription
+
+T = TypeVar("T", bound=BaseModelDescription)
+
+
+class ModelManagement(Generic[T]):
+    METADATA_FILE = "files_metadata.json"
+
+    @classmethod
+    def list_supported_models(cls) -> list[dict[str, Any]]:
+        """Lists the supported models.
+
+        Returns:
+            list[T]: A list of dictionaries containing the model information.
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    def add_custom_model(
+        cls,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        """Add a custom model to the existing embedding classes based on the passed model descriptions
+
+        Model description dict should contain the fields same as in one of the model descriptions presented
+        in fastembed.common.model_description
+
+        E.g. for BaseModelDescription:
+            model: str
+            sources: ModelSource
+            model_file: str
+            description: str
+            license: str
+            size_in_GB: float
+            additional_files: list[str]
+
+        Returns:
+            None
+        """
+        raise NotImplementedError()
+
+    @classmethod
+    def _list_supported_models(cls) -> list[T]:
+        raise NotImplementedError()
+
+    @classmethod
+    def _get_model_description(cls, model_name: str) -> T:
+        """
+        Gets the model description from the model_name.
+
+        Args:
+            model_name (str): The name of the model.
+
+        raises:
+            ValueError: If the model_name is not supported.
+
+        Returns:
+            T: The model description.
+        """
+        for model in cls._list_supported_models():
+            if model_name.lower() == model.model.lower():
+                return model
+
+        raise ValueError(f"Model {model_name} is not supported in {cls.__name__}.")
+
+    @classmethod
+    def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool = True) -> str:
+        """
+        Downloads a file from Google Cloud Storage.
+
+        Args:
+            url (str): The URL to download the file from.
+            output_path (str): The path to save the downloaded file to.
+            show_progress (bool, optional): Whether to show a progress bar. Defaults to True.
+
+        Returns:
+            str: The path to the downloaded file.
+        """
+
+        if os.path.exists(output_path):
+            return output_path
+        response = requests.get(url, stream=True)
+
+        # Handle HTTP errors
+        if response.status_code == 403:
+            raise PermissionError(
+                "Authentication Error: You do not have permission to access this resource. "
+                "Please check your credentials."
+            )
+
+        # Get the total size of the file
+        total_size_in_bytes = int(response.headers.get("content-length", 0))
+
+        # Warn if the total size is zero
+        if total_size_in_bytes == 0:
+            print(f"Warning: Content-length header is missing or zero in the response from {url}.")
+
+        show_progress = bool(total_size_in_bytes and show_progress)
+
+        with tqdm(
+            total=total_size_in_bytes,
+            unit="iB",
+            unit_scale=True,
+            disable=not show_progress,
+        ) as progress_bar:
+            with open(output_path, "wb") as file:
+                for chunk in response.iter_content(chunk_size=1024):
+                    if chunk:  # Filter out keep-alive new chunks
+                        progress_bar.update(len(chunk))
+                        file.write(chunk)
+        return output_path
+
+    @classmethod
+    def download_files_from_huggingface(
+        cls,
+        hf_source_repo: str,
+        cache_dir: str,
+        extra_patterns: list[str],
+        local_files_only: bool = False,
+        **kwargs: Any,
+    ) -> str:
+        """
+        Downloads a model from HuggingFace Hub.
+        Args:
+            hf_source_repo (str): Name of the model on HuggingFace Hub, e.g. "qdrant/all-MiniLM-L6-v2-onnx".
+            cache_dir (Optional[str]): The path to the cache directory.
+            extra_patterns (list[str]): extra patterns to allow in the snapshot download, typically
+                includes the required model files.
+            local_files_only (bool, optional): Whether to only use local files. Defaults to False.
+        Returns:
+            Path: The path to the model directory.
+        """
+
+        def _verify_files_from_metadata(
+            model_dir: Path, stored_metadata: dict[str, Any], repo_files: list[RepoFile]
+        ) -> bool:
+            try:
+                for rel_path, meta in stored_metadata.items():
+                    file_path = model_dir / rel_path
+
+                    if not file_path.exists():
+                        return False
+
+                    if repo_files:  # online verification
+                        file_info = next((f for f in repo_files if f.path == file_path.name), None)
+                        if (
+                            not file_info
+                            or file_info.size != meta["size"]
+                            or file_info.blob_id != meta["blob_id"]
+                        ):
+                            return False
+
+                    else:  # offline verification
+                        if file_path.stat().st_size != meta["size"]:
+                            return False
+                return True
+            except (OSError, KeyError) as e:
+                logger.error(f"Error verifying files: {str(e)}")
+                return False
+
+        def _collect_file_metadata(
+            model_dir: Path, repo_files: list[RepoFile]
+        ) -> dict[str, dict[str, int | str]]:
+            meta: dict[str, dict[str, int | str]] = {}
+            file_info_map = {f.path: f for f in repo_files}
+            for file_path in model_dir.rglob("*"):
+                if file_path.is_file() and file_path.name != cls.METADATA_FILE:
+                    repo_file = file_info_map.get(file_path.name)
+                    if repo_file:
+                        meta[str(file_path.relative_to(model_dir))] = {
+                            "size": repo_file.size,
+                            "blob_id": repo_file.blob_id,
+                        }
+            return meta
+
+        def _save_file_metadata(model_dir: Path, meta: dict[str, dict[str, int | str]]) -> None:
+            try:
+                if not model_dir.exists():
+                    model_dir.mkdir(parents=True, exist_ok=True)
+                (model_dir / cls.METADATA_FILE).write_text(json.dumps(meta))
+            except (OSError, ValueError) as e:
+                logger.warning(f"Error saving metadata: {str(e)}")
+
+        allow_patterns = [
+            "config.json",
+            "tokenizer.json",
+            "tokenizer_config.json",
+            "special_tokens_map.json",
+            "preprocessor_config.json",
+        ]
+
+        allow_patterns.extend(extra_patterns)
+
+        snapshot_dir = Path(cache_dir) / f"models--{hf_source_repo.replace('/', '--')}"
+        metadata_file = snapshot_dir / cls.METADATA_FILE
+
+        if local_files_only:
+            disable_progress_bars()
+            if metadata_file.exists():
+                metadata = json.loads(metadata_file.read_text())
+                verified = _verify_files_from_metadata(snapshot_dir, metadata, repo_files=[])
+                if not verified:
+                    logger.warning(
+                        "Local file sizes do not match the metadata."
+                    )  # do not raise, still make an attempt to load the model
+            result = snapshot_download(
+                repo_id=hf_source_repo,
+                allow_patterns=allow_patterns,
+                cache_dir=cache_dir,
+                local_files_only=local_files_only,
+                **kwargs,
+            )
+            return result
+
+        repo_revision = model_info(hf_source_repo).sha
+        repo_tree = list(list_repo_tree(hf_source_repo, revision=repo_revision, repo_type="model"))
+
+        allowed_extensions = {".json", ".onnx", ".txt"}
+        repo_files = (
+            [
+                f
+                for f in repo_tree
+                if isinstance(f, RepoFile) and Path(f.path).suffix in allowed_extensions
+            ]
+            if repo_tree
+            else []
+        )
+
+        verified_metadata = False
+
+        if snapshot_dir.exists() and metadata_file.exists():
+            metadata = json.loads(metadata_file.read_text())
+            verified_metadata = _verify_files_from_metadata(snapshot_dir, metadata, repo_files)
+
+            if verified_metadata:
+                disable_progress_bars()
+
+        result = snapshot_download(
+            repo_id=hf_source_repo,
+            allow_patterns=allow_patterns,
+            cache_dir=cache_dir,
+            local_files_only=local_files_only,
+            **kwargs,
+        )
+
+        if (
+            not verified_metadata
+        ):  # metadata is not up-to-date, update it and check whether the files have been
+            # downloaded correctly
+            metadata = _collect_file_metadata(snapshot_dir, repo_files)
+
+            download_successful = _verify_files_from_metadata(
+                snapshot_dir, metadata, repo_files=[]
+            )  # offline verification
+            if not download_successful:
+                raise ValueError(
+                    "Files have been corrupted during downloading process. "
+                    "Please check your internet connection and try again."
+                )
+            _save_file_metadata(snapshot_dir, metadata)
+
+        return result
+
+    @classmethod
+    def decompress_to_cache(cls, targz_path: str, cache_dir: str) -> str:
+        """
+        Decompresses a .tar.gz file to a cache directory.
+
+        Args:
+            targz_path (str): Path to the .tar.gz file.
+            cache_dir (str): Path to the cache directory.
+
+        Returns:
+            cache_dir (str): Path to the cache directory.
+        """
+        # Check if targz_path exists and is a file
+        if not os.path.isfile(targz_path):
+            raise ValueError(f"{targz_path} does not exist or is not a file.")
+
+        # Check if targz_path is a .tar.gz file
+        if not targz_path.endswith(".tar.gz"):
+            raise ValueError(f"{targz_path} is not a .tar.gz file.")
+
+        try:
+            # Open the tar.gz file
+            with tarfile.open(targz_path, "r:gz") as tar:
+                # Extract all files into the cache directory
+                tar.extractall(
+                    path=cache_dir,
+                )
+        except tarfile.TarError as e:
+            # If any error occurs while opening or extracting the tar.gz file,
+            # delete the cache directory (if it was created in this function)
+            # and raise the error again
+            if "tmp" in cache_dir:
+                shutil.rmtree(cache_dir)
+            raise ValueError(f"An error occurred while decompressing {targz_path}: {e}")
+
+        return cache_dir
+
+    @classmethod
+    def retrieve_model_gcs(
+        cls,
+        model_name: str,
+        source_url: str,
+        cache_dir: str,
+        deprecated_tar_struct: bool = False,
+        local_files_only: bool = False,
+    ) -> Path:
+        fast_model_name = f"{'fast-' if deprecated_tar_struct else ''}{model_name.split('/')[-1]}"
+        cache_tmp_dir = Path(cache_dir) / "tmp"
+        model_tmp_dir = cache_tmp_dir / fast_model_name
+        model_dir = Path(cache_dir) / fast_model_name
+
+        # check if the model_dir and the model files are both present for macOS
+        if model_dir.exists() and len(list(model_dir.glob("*"))) > 0:
+            return model_dir
+
+        if model_tmp_dir.exists():
+            shutil.rmtree(model_tmp_dir)
+
+        cache_tmp_dir.mkdir(parents=True, exist_ok=True)
+
+        model_tar_gz = Path(cache_dir) / f"{fast_model_name}.tar.gz"
+
+        if model_tar_gz.exists():
+            model_tar_gz.unlink()
+
+        if not local_files_only:
+            cls.download_file_from_gcs(
+                source_url,
+                output_path=str(model_tar_gz),
+            )
+
+            cls.decompress_to_cache(targz_path=str(model_tar_gz), cache_dir=str(cache_tmp_dir))
+            assert model_tmp_dir.exists(), f"Could not find {model_tmp_dir} in {cache_tmp_dir}"
+
+            model_tar_gz.unlink()
+            # Rename from tmp to final name is atomic
+            model_tmp_dir.rename(model_dir)
+        else:
+            logger.error(
+                f"Could not find the model tar.gz file at {model_dir} and local_files_only=True."
+            )
+            raise ValueError(
+                f"Could not find the model tar.gz file at {model_dir} and local_files_only=True."
+            )
+
+        return model_dir
+
+    @classmethod
+    def download_model(cls, model: T, cache_dir: str, retries: int = 3, **kwargs: Any) -> Path:
+        """
+        Downloads a model from HuggingFace Hub or Google Cloud Storage.
+
+        Args:
+            model (T): The model description.
+                Example:
+                ```
+                {
+                    "model": "BAAI/bge-base-en-v1.5",
+                    "dim": 768,
+                    "description": "Base English model, v1.5",
+                    "size_in_GB": 0.44,
+                    "sources": {
+                        "url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en-v1.5.tar.gz",
+                        "hf": "qdrant/bge-base-en-v1.5-onnx-q",
+                    }
+                }
+                ```
+            cache_dir (str): The path to the cache directory.
+            retries: (int): The number of times to retry (including the first attempt)
+
+        Returns:
+            Path: The path to the downloaded model directory.
+        """
+        local_files_only = kwargs.get("local_files_only", False)
+        specific_model_path: str | None = kwargs.pop("specific_model_path", None)
+        if specific_model_path:
+            return Path(specific_model_path)
+        retries = 1 if local_files_only else retries
+        hf_source = model.sources.hf
+        url_source = model.sources.url
+
+        extra_patterns = [model.model_file]
+        extra_patterns.extend(model.additional_files)
+
+        if hf_source:
+            try:
+                cache_kwargs = deepcopy(kwargs)
+                cache_kwargs["local_files_only"] = True
+                return Path(
+                    cls.download_files_from_huggingface(
+                        hf_source,
+                        cache_dir=cache_dir,
+                        extra_patterns=extra_patterns,
+                        **cache_kwargs,
+                    )
+                )
+            except Exception:
+                pass
+            finally:
+                enable_progress_bars()
+
+        sleep = 3.0
+        while retries > 0:
+            retries -= 1
+
+            if hf_source and not local_files_only:
+                # we have already tried loading with `local_files_only=True` via hf and we failed
+                try:
+                    return Path(
+                        cls.download_files_from_huggingface(
+                            hf_source,
+                            cache_dir=cache_dir,
+                            extra_patterns=extra_patterns,
+                            **kwargs,
+                        )
+                    )
+                except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
+                    if not local_files_only:
+                        logger.error(
+                            f"Could not download model from HuggingFace: {e} "
+                            "Falling back to other sources."
+                        )
+                finally:
+                    enable_progress_bars()
+            if url_source or local_files_only:
+                try:
+                    return cls.retrieve_model_gcs(
+                        model.model,
+                        str(url_source),
+                        str(cache_dir),
+                        deprecated_tar_struct=model.sources.deprecated_tar_struct,
+                        local_files_only=local_files_only,
+                    )
+                except Exception:
+                    if not local_files_only:
+                        logger.error(f"Could not download model from url: {url_source}")
+
+            if local_files_only:
+                logger.error("Could not find model in cache_dir")
+                break
+            else:
+                logger.error(
+                    f"Could not download model from either source, sleeping for {sleep} seconds, {retries} retries left."
+                )
+                time.sleep(sleep)
+                sleep *= 3
+
+        raise ValueError(f"Could not load model {model.model} from any source.")
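
For orientation, the sketch below shows how the `ModelManagement` hooks above fit together in a concrete class: `_list_supported_models` supplies the registry that `_get_model_description` searches, and `download_model` then resolves the description against its HuggingFace/GCS sources. It is illustrative only, not part of the released package; the registry entry, cache directory, and the assumption that `BaseModelDescription`/`ModelSource` (defined in fastembed/common/model_description.py, not shown in this hunk) accept these fields as keyword arguments are all examples.

```python
from fastembed.common.model_description import BaseModelDescription, ModelSource
from fastembed.common.model_management import ModelManagement


class ExampleModelManager(ModelManagement[BaseModelDescription]):
    # Hypothetical one-entry registry; the field values are illustrative.
    _MODELS: list[BaseModelDescription] = [
        BaseModelDescription(
            model="qdrant/bge-base-en-v1.5-onnx-q",
            sources=ModelSource(hf="qdrant/bge-base-en-v1.5-onnx-q"),
            model_file="model_optimized.onnx",
            description="Example registry entry",
            license="apache-2.0",
            size_in_GB=0.44,
            additional_files=[],
        )
    ]

    @classmethod
    def _list_supported_models(cls) -> list[BaseModelDescription]:
        return cls._MODELS


if __name__ == "__main__":
    # Resolve a name to its description, then fetch the files into a local cache.
    description = ExampleModelManager._get_model_description("qdrant/bge-base-en-v1.5-onnx-q")
    model_dir = ExampleModelManager.download_model(description, cache_dir="model_cache", retries=3)
    print(f"Model files downloaded to: {model_dir}")
```

Because the HuggingFace source is set, `download_model` first tries a cached `local_files_only` snapshot, then falls back to downloading from the Hub with retries, mirroring the control flow in the hunk above.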
fastembed/common/onnx_model.py
@@ -0,0 +1,188 @@
+import warnings
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Generic, Iterable, Sequence, Type, TypeVar
+
+import numpy as np
+import onnxruntime as ort
+
+from numpy.typing import NDArray
+from tokenizers import Tokenizer
+
+from fastembed.common.types import OnnxProvider, NumpyArray, Device
+from fastembed.parallel_processor import Worker
+
+# Holds type of the embedding result
+T = TypeVar("T")
+
+
+@dataclass
+class OnnxOutputContext:
+    model_output: NumpyArray
+    attention_mask: NDArray[np.int64] | None = None
+    input_ids: NDArray[np.int64] | None = None
+    metadata: dict[str, Any] | None = None
+
+
+class OnnxModel(Generic[T]):
+    EXPOSED_SESSION_OPTIONS = ("enable_cpu_mem_arena",)
+
+    @classmethod
+    def _get_worker_class(cls) -> Type["EmbeddingWorker[T]"]:
+        raise NotImplementedError("Subclasses must implement this method")
+
+    def _post_process_onnx_output(self, output: OnnxOutputContext, **kwargs: Any) -> Iterable[T]:
+        """Post-process the ONNX model output to convert it into a usable format.
+
+        Args:
+            output (OnnxOutputContext): The raw output from the ONNX model.
+            **kwargs: Additional keyword arguments that may be needed by specific implementations.
+
+        Returns:
+            Iterable[T]: Post-processed output as an iterable of type T.
+        """
+        raise NotImplementedError("Subclasses must implement this method")
+
+    def __init__(self) -> None:
+        self.model: ort.InferenceSession | None = None
+        self.tokenizer: Tokenizer | None = None
+
+    def _preprocess_onnx_input(
+        self, onnx_input: dict[str, NumpyArray], **kwargs: Any
+    ) -> dict[str, NumpyArray]:
+        """
+        Preprocess the onnx input.
+        """
+        return onnx_input
+
+    def _load_onnx_model(
+        self,
+        model_dir: Path,
+        model_file: str,
+        threads: int | None,
+        providers: Sequence[OnnxProvider] | None = None,
+        cuda: bool | Device = Device.AUTO,
+        device_id: int | None = None,
+        extra_session_options: dict[str, Any] | None = None,
+    ) -> None:
+        model_path = model_dir / model_file
+        # List of Execution Providers: https://onnxruntime.ai/docs/execution-providers
+        available_providers = ort.get_available_providers()
+        cuda_available = "CUDAExecutionProvider" in available_providers
+        explicit_cuda = cuda is True or cuda == Device.CUDA
+
+        if explicit_cuda and providers is not None:
+            warnings.warn(
+                f"`cuda` and `providers` are mutually exclusive parameters, "
+                f"cuda: {cuda}, providers: {providers}. If you'd like to use providers, cuda should be one of "
+                f"[False, Device.CPU, Device.AUTO].",
+                category=UserWarning,
+                stacklevel=6,
+            )
+
+        if providers is not None:
+            onnx_providers = list(providers)
+        elif explicit_cuda or (cuda == Device.AUTO and cuda_available):
+            if device_id is None:
+                onnx_providers = ["CUDAExecutionProvider"]
+            else:
+                onnx_providers = [("CUDAExecutionProvider", {"device_id": device_id})]
+        else:
+            onnx_providers = ["CPUExecutionProvider"]
+
+        requested_provider_names: list[str] = []
+        for provider in onnx_providers:
+            # check providers available
+            provider_name = provider if isinstance(provider, str) else provider[0]
+            requested_provider_names.append(provider_name)
+            if provider_name not in available_providers:
+                raise ValueError(
+                    f"Provider {provider_name} is not available. Available providers: {available_providers}"
+                )
+
+        so = ort.SessionOptions()
+        so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+
+        if threads is not None:
+            so.intra_op_num_threads = threads
+            so.inter_op_num_threads = threads
+
+        if extra_session_options is not None:
+            self.add_extra_session_options(so, extra_session_options)
+
+        self.model = ort.InferenceSession(
+            str(model_path), providers=onnx_providers, sess_options=so
+        )
+        if "CUDAExecutionProvider" in requested_provider_names:
+            assert self.model is not None
+            current_providers = self.model.get_providers()
+            if "CUDAExecutionProvider" not in current_providers:
+                warnings.warn(
+                    f"Attempt to set CUDAExecutionProvider failed. Current providers: {current_providers}."
+                    "If you are using CUDA 12.x, install onnxruntime-gpu via "
+                    "`pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/`",
+                    RuntimeWarning,
+                )
+
+    @classmethod
+    def _select_exposed_session_options(cls, model_kwargs: dict[str, Any]) -> dict[str, Any]:
+        """A convenience method to select the exposed session options in models
+
+        Args:
+            model_kwargs (dict[str, Any]): The model kwargs.
+
+        Returns:
+            dict[str, Any]: a dict with filtered exposed session options.
+        """
+        return {k: v for k, v in model_kwargs.items() if k in cls.EXPOSED_SESSION_OPTIONS}
+
+    @classmethod
+    def add_extra_session_options(
+        cls, session_options: ort.SessionOptions, extra_options: dict[str, Any]
+    ) -> None:
+        """Add extra session options to the existing options object in-place
+
+        Args:
+            session_options (ort.SessionOptions): The existing session options object.
+            extra_options (dict[str, Any]): The extra session options available in cls.EXPOSED_SESSION_OPTIONS.
+
+        Returns:
+            None
+        """
+        for option in extra_options:
+            assert (
+                option in cls.EXPOSED_SESSION_OPTIONS
+            ), f"{option} is unknown or not exposed (exposed options: {cls.EXPOSED_SESSION_OPTIONS})"
+        if "enable_cpu_mem_arena" in extra_options:
+            session_options.enable_cpu_mem_arena = extra_options["enable_cpu_mem_arena"]
+
+    def load_onnx_model(self) -> None:
+        raise NotImplementedError("Subclasses must implement this method")
+
+    def onnx_embed(self, *args: Any, **kwargs: Any) -> OnnxOutputContext:
+        raise NotImplementedError("Subclasses must implement this method")
+
+
+class EmbeddingWorker(Worker, Generic[T]):
+    def init_embedding(
+        self,
+        model_name: str,
+        cache_dir: str,
+        **kwargs: Any,
+    ) -> OnnxModel[T]:
+        raise NotImplementedError()
+
+    def __init__(
+        self,
+        model_name: str,
+        cache_dir: str,
+        **kwargs: Any,
+    ):
+        self.model = self.init_embedding(model_name, cache_dir, **kwargs)
+
+    @classmethod
+    def start(cls, model_name: str, cache_dir: str, **kwargs: Any) -> "EmbeddingWorker[T]":
+        return cls(model_name=model_name, cache_dir=cache_dir, **kwargs)
+
+    def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
+        raise NotImplementedError("Subclasses must implement this method")