fastembed_bio-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. fastembed/__init__.py +24 -0
  2. fastembed/bio/__init__.py +3 -0
  3. fastembed/bio/protein_embedding.py +456 -0
  4. fastembed/common/__init__.py +3 -0
  5. fastembed/common/model_description.py +52 -0
  6. fastembed/common/model_management.py +471 -0
  7. fastembed/common/onnx_model.py +188 -0
  8. fastembed/common/preprocessor_utils.py +84 -0
  9. fastembed/common/types.py +27 -0
  10. fastembed/common/utils.py +69 -0
  11. fastembed/embedding.py +24 -0
  12. fastembed/image/__init__.py +3 -0
  13. fastembed/image/image_embedding.py +135 -0
  14. fastembed/image/image_embedding_base.py +55 -0
  15. fastembed/image/onnx_embedding.py +217 -0
  16. fastembed/image/onnx_image_model.py +156 -0
  17. fastembed/image/transform/functional.py +221 -0
  18. fastembed/image/transform/operators.py +499 -0
  19. fastembed/late_interaction/__init__.py +5 -0
  20. fastembed/late_interaction/colbert.py +301 -0
  21. fastembed/late_interaction/jina_colbert.py +58 -0
  22. fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
  23. fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
  24. fastembed/late_interaction/token_embeddings.py +83 -0
  25. fastembed/late_interaction_multimodal/__init__.py +5 -0
  26. fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
  27. fastembed/late_interaction_multimodal/colpali.py +327 -0
  28. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
  29. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
  30. fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
  31. fastembed/parallel_processor.py +253 -0
  32. fastembed/postprocess/__init__.py +3 -0
  33. fastembed/postprocess/muvera.py +362 -0
  34. fastembed/py.typed +1 -0
  35. fastembed/rerank/cross_encoder/__init__.py +3 -0
  36. fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
  37. fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
  38. fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
  39. fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
  40. fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
  41. fastembed/sparse/__init__.py +4 -0
  42. fastembed/sparse/bm25.py +359 -0
  43. fastembed/sparse/bm42.py +369 -0
  44. fastembed/sparse/minicoil.py +372 -0
  45. fastembed/sparse/sparse_embedding_base.py +90 -0
  46. fastembed/sparse/sparse_text_embedding.py +143 -0
  47. fastembed/sparse/splade_pp.py +196 -0
  48. fastembed/sparse/utils/minicoil_encoder.py +146 -0
  49. fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
  50. fastembed/sparse/utils/tokenizer.py +120 -0
  51. fastembed/sparse/utils/vocab_resolver.py +202 -0
  52. fastembed/text/__init__.py +3 -0
  53. fastembed/text/clip_embedding.py +56 -0
  54. fastembed/text/custom_text_embedding.py +97 -0
  55. fastembed/text/multitask_embedding.py +109 -0
  56. fastembed/text/onnx_embedding.py +353 -0
  57. fastembed/text/onnx_text_model.py +180 -0
  58. fastembed/text/pooled_embedding.py +136 -0
  59. fastembed/text/pooled_normalized_embedding.py +164 -0
  60. fastembed/text/text_embedding.py +228 -0
  61. fastembed/text/text_embedding_base.py +75 -0
  62. fastembed_bio-0.1.0.dist-info/METADATA +339 -0
  63. fastembed_bio-0.1.0.dist-info/RECORD +66 -0
  64. fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
  65. fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
  66. fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
fastembed/common/model_management.py
@@ -0,0 +1,471 @@
+ import os
+ import time
+ import json
+ import shutil
+ import tarfile
+ from copy import deepcopy
+ from pathlib import Path
+ from typing import Any, TypeVar, Generic
+
+ import requests
+ from huggingface_hub import snapshot_download, model_info, list_repo_tree
+ from huggingface_hub.hf_api import RepoFile
+ from huggingface_hub.utils import (
+     RepositoryNotFoundError,
+     disable_progress_bars,
+     enable_progress_bars,
+ )
+ from loguru import logger
+ from tqdm import tqdm
+ from fastembed.common.model_description import BaseModelDescription
+
+ T = TypeVar("T", bound=BaseModelDescription)
+
+
+ class ModelManagement(Generic[T]):
+     METADATA_FILE = "files_metadata.json"
+
+     @classmethod
+     def list_supported_models(cls) -> list[dict[str, Any]]:
+         """Lists the supported models.
+
+         Returns:
+             list[dict[str, Any]]: A list of dictionaries containing the model information.
+         """
+         raise NotImplementedError()
+
+     @classmethod
+     def add_custom_model(
+         cls,
+         *args: Any,
+         **kwargs: Any,
+     ) -> None:
+         """Add a custom model to the existing embedding classes based on the passed model description.
+
+         The model description dict should contain the same fields as one of the model descriptions
+         defined in fastembed.common.model_description.
+
+         E.g. for BaseModelDescription:
+             model: str
+             sources: ModelSource
+             model_file: str
+             description: str
+             license: str
+             size_in_GB: float
+             additional_files: list[str]
+
+         Returns:
+             None
+         """
+         raise NotImplementedError()
+
+     @classmethod
+     def _list_supported_models(cls) -> list[T]:
+         raise NotImplementedError()
+
+     @classmethod
+     def _get_model_description(cls, model_name: str) -> T:
+         """
+         Gets the model description from the model_name.
+
+         Args:
+             model_name (str): The name of the model.
+
+         Raises:
+             ValueError: If the model_name is not supported.
+
+         Returns:
+             T: The model description.
+         """
+         for model in cls._list_supported_models():
+             if model_name.lower() == model.model.lower():
+                 return model
+
+         raise ValueError(f"Model {model_name} is not supported in {cls.__name__}.")
+
+     @classmethod
+     def download_file_from_gcs(cls, url: str, output_path: str, show_progress: bool = True) -> str:
+         """
+         Downloads a file from Google Cloud Storage.
+
+         Args:
+             url (str): The URL to download the file from.
+             output_path (str): The path to save the downloaded file to.
+             show_progress (bool, optional): Whether to show a progress bar. Defaults to True.
+
+         Returns:
+             str: The path to the downloaded file.
+         """
+
+         if os.path.exists(output_path):
+             return output_path
+         response = requests.get(url, stream=True)
+
+         # Handle HTTP errors
+         if response.status_code == 403:
+             raise PermissionError(
+                 "Authentication Error: You do not have permission to access this resource. "
+                 "Please check your credentials."
+             )
+
+         # Get the total size of the file
+         total_size_in_bytes = int(response.headers.get("content-length", 0))
+
+         # Warn if the total size is zero
+         if total_size_in_bytes == 0:
+             print(f"Warning: Content-length header is missing or zero in the response from {url}.")
+
+         show_progress = bool(total_size_in_bytes and show_progress)
+
+         with tqdm(
+             total=total_size_in_bytes,
+             unit="iB",
+             unit_scale=True,
+             disable=not show_progress,
+         ) as progress_bar:
+             with open(output_path, "wb") as file:
+                 for chunk in response.iter_content(chunk_size=1024):
+                     if chunk:  # Filter out keep-alive chunks
+                         progress_bar.update(len(chunk))
+                         file.write(chunk)
+         return output_path
+
+     @classmethod
+     def download_files_from_huggingface(
+         cls,
+         hf_source_repo: str,
+         cache_dir: str,
+         extra_patterns: list[str],
+         local_files_only: bool = False,
+         **kwargs: Any,
+     ) -> str:
+         """
+         Downloads a model from HuggingFace Hub.
+         Args:
+             hf_source_repo (str): Name of the model on HuggingFace Hub, e.g. "qdrant/all-MiniLM-L6-v2-onnx".
+             cache_dir (str): The path to the cache directory.
+             extra_patterns (list[str]): extra patterns to allow in the snapshot download, typically
+                 includes the required model files.
+             local_files_only (bool, optional): Whether to only use local files. Defaults to False.
+         Returns:
+             str: The path to the model directory.
+         """
+
+         def _verify_files_from_metadata(
+             model_dir: Path, stored_metadata: dict[str, Any], repo_files: list[RepoFile]
+         ) -> bool:
+             try:
+                 for rel_path, meta in stored_metadata.items():
+                     file_path = model_dir / rel_path
+
+                     if not file_path.exists():
+                         return False
+
+                     if repo_files:  # online verification
+                         file_info = next((f for f in repo_files if f.path == file_path.name), None)
+                         if (
+                             not file_info
+                             or file_info.size != meta["size"]
+                             or file_info.blob_id != meta["blob_id"]
+                         ):
+                             return False
+
+                     else:  # offline verification
+                         if file_path.stat().st_size != meta["size"]:
+                             return False
+                 return True
+             except (OSError, KeyError) as e:
+                 logger.error(f"Error verifying files: {str(e)}")
+                 return False
+
+         def _collect_file_metadata(
+             model_dir: Path, repo_files: list[RepoFile]
+         ) -> dict[str, dict[str, int | str]]:
+             meta: dict[str, dict[str, int | str]] = {}
+             file_info_map = {f.path: f for f in repo_files}
+             for file_path in model_dir.rglob("*"):
+                 if file_path.is_file() and file_path.name != cls.METADATA_FILE:
+                     repo_file = file_info_map.get(file_path.name)
+                     if repo_file:
+                         meta[str(file_path.relative_to(model_dir))] = {
+                             "size": repo_file.size,
+                             "blob_id": repo_file.blob_id,
+                         }
+             return meta
+
+         def _save_file_metadata(model_dir: Path, meta: dict[str, dict[str, int | str]]) -> None:
+             try:
+                 if not model_dir.exists():
+                     model_dir.mkdir(parents=True, exist_ok=True)
+                 (model_dir / cls.METADATA_FILE).write_text(json.dumps(meta))
+             except (OSError, ValueError) as e:
+                 logger.warning(f"Error saving metadata: {str(e)}")
+
+         allow_patterns = [
+             "config.json",
+             "tokenizer.json",
+             "tokenizer_config.json",
+             "special_tokens_map.json",
+             "preprocessor_config.json",
+         ]
+
+         allow_patterns.extend(extra_patterns)
+
+         snapshot_dir = Path(cache_dir) / f"models--{hf_source_repo.replace('/', '--')}"
+         metadata_file = snapshot_dir / cls.METADATA_FILE
+
+         if local_files_only:
+             disable_progress_bars()
+             if metadata_file.exists():
+                 metadata = json.loads(metadata_file.read_text())
+                 verified = _verify_files_from_metadata(snapshot_dir, metadata, repo_files=[])
+                 if not verified:
+                     logger.warning(
+                         "Local file sizes do not match the metadata."
+                     )  # do not raise, still make an attempt to load the model
+             result = snapshot_download(
+                 repo_id=hf_source_repo,
+                 allow_patterns=allow_patterns,
+                 cache_dir=cache_dir,
+                 local_files_only=local_files_only,
+                 **kwargs,
+             )
+             return result
+
+         repo_revision = model_info(hf_source_repo).sha
+         repo_tree = list(list_repo_tree(hf_source_repo, revision=repo_revision, repo_type="model"))
+
+         allowed_extensions = {".json", ".onnx", ".txt"}
+         repo_files = (
+             [
+                 f
+                 for f in repo_tree
+                 if isinstance(f, RepoFile) and Path(f.path).suffix in allowed_extensions
+             ]
+             if repo_tree
+             else []
+         )
+
+         verified_metadata = False
+
+         if snapshot_dir.exists() and metadata_file.exists():
+             metadata = json.loads(metadata_file.read_text())
+             verified_metadata = _verify_files_from_metadata(snapshot_dir, metadata, repo_files)
+
+             if verified_metadata:
+                 disable_progress_bars()
+
+         result = snapshot_download(
+             repo_id=hf_source_repo,
+             allow_patterns=allow_patterns,
+             cache_dir=cache_dir,
+             local_files_only=local_files_only,
+             **kwargs,
+         )
+
+         if (
+             not verified_metadata
+         ):  # metadata is not up-to-date, update it and check whether the files have been
+             # downloaded correctly
+             metadata = _collect_file_metadata(snapshot_dir, repo_files)
+
+             download_successful = _verify_files_from_metadata(
+                 snapshot_dir, metadata, repo_files=[]
+             )  # offline verification
+             if not download_successful:
+                 raise ValueError(
+                     "Files have been corrupted during the download process. "
+                     "Please check your internet connection and try again."
+                 )
+             _save_file_metadata(snapshot_dir, metadata)
+
+         return result
+
+     @classmethod
+     def decompress_to_cache(cls, targz_path: str, cache_dir: str) -> str:
+         """
+         Decompresses a .tar.gz file to a cache directory.
+
+         Args:
+             targz_path (str): Path to the .tar.gz file.
+             cache_dir (str): Path to the cache directory.
+
+         Returns:
+             cache_dir (str): Path to the cache directory.
+         """
+         # Check if targz_path exists and is a file
+         if not os.path.isfile(targz_path):
+             raise ValueError(f"{targz_path} does not exist or is not a file.")
+
+         # Check if targz_path is a .tar.gz file
+         if not targz_path.endswith(".tar.gz"):
+             raise ValueError(f"{targz_path} is not a .tar.gz file.")
+
+         try:
+             # Open the tar.gz file
+             with tarfile.open(targz_path, "r:gz") as tar:
+                 # Extract all files into the cache directory
+                 tar.extractall(
+                     path=cache_dir,
+                 )
+         except tarfile.TarError as e:
+             # If any error occurs while opening or extracting the tar.gz file,
+             # delete the cache directory (if it was created in this function)
+             # and raise the error again
+             if "tmp" in cache_dir:
+                 shutil.rmtree(cache_dir)
+             raise ValueError(f"An error occurred while decompressing {targz_path}: {e}")
+
+         return cache_dir
+
+     @classmethod
+     def retrieve_model_gcs(
+         cls,
+         model_name: str,
+         source_url: str,
+         cache_dir: str,
+         deprecated_tar_struct: bool = False,
+         local_files_only: bool = False,
+     ) -> Path:
+         fast_model_name = f"{'fast-' if deprecated_tar_struct else ''}{model_name.split('/')[-1]}"
+         cache_tmp_dir = Path(cache_dir) / "tmp"
+         model_tmp_dir = cache_tmp_dir / fast_model_name
+         model_dir = Path(cache_dir) / fast_model_name
+
+         # check if the model_dir and the model files are both present for macOS
+         if model_dir.exists() and len(list(model_dir.glob("*"))) > 0:
+             return model_dir
+
+         if model_tmp_dir.exists():
+             shutil.rmtree(model_tmp_dir)
+
+         cache_tmp_dir.mkdir(parents=True, exist_ok=True)
+
+         model_tar_gz = Path(cache_dir) / f"{fast_model_name}.tar.gz"
+
+         if model_tar_gz.exists():
+             model_tar_gz.unlink()
+
+         if not local_files_only:
+             cls.download_file_from_gcs(
+                 source_url,
+                 output_path=str(model_tar_gz),
+             )
+
+             cls.decompress_to_cache(targz_path=str(model_tar_gz), cache_dir=str(cache_tmp_dir))
+             assert model_tmp_dir.exists(), f"Could not find {model_tmp_dir} in {cache_tmp_dir}"
+
+             model_tar_gz.unlink()
+             # Renaming from tmp to the final name is atomic
+             model_tmp_dir.rename(model_dir)
+         else:
+             logger.error(
+                 f"Could not find the model tar.gz file at {model_dir} and local_files_only=True."
+             )
+             raise ValueError(
+                 f"Could not find the model tar.gz file at {model_dir} and local_files_only=True."
+             )
+
+         return model_dir
+
+     @classmethod
+     def download_model(cls, model: T, cache_dir: str, retries: int = 3, **kwargs: Any) -> Path:
+         """
+         Downloads a model from HuggingFace Hub or Google Cloud Storage.
+
+         Args:
+             model (T): The model description.
+                 Example:
+                 ```
+                 {
+                     "model": "BAAI/bge-base-en-v1.5",
+                     "dim": 768,
+                     "description": "Base English model, v1.5",
+                     "size_in_GB": 0.44,
+                     "sources": {
+                         "url": "https://storage.googleapis.com/qdrant-fastembed/fast-bge-base-en-v1.5.tar.gz",
+                         "hf": "qdrant/bge-base-en-v1.5-onnx-q",
+                     }
+                 }
+                 ```
+             cache_dir (str): The path to the cache directory.
+             retries (int): The number of times to retry (including the first attempt).
+
+         Returns:
+             Path: The path to the downloaded model directory.
+         """
+         local_files_only = kwargs.get("local_files_only", False)
+         specific_model_path: str | None = kwargs.pop("specific_model_path", None)
+         if specific_model_path:
+             return Path(specific_model_path)
+         retries = 1 if local_files_only else retries
+         hf_source = model.sources.hf
+         url_source = model.sources.url
+
+         extra_patterns = [model.model_file]
+         extra_patterns.extend(model.additional_files)
+
+         if hf_source:
+             try:
+                 cache_kwargs = deepcopy(kwargs)
+                 cache_kwargs["local_files_only"] = True
+                 return Path(
+                     cls.download_files_from_huggingface(
+                         hf_source,
+                         cache_dir=cache_dir,
+                         extra_patterns=extra_patterns,
+                         **cache_kwargs,
+                     )
+                 )
+             except Exception:
+                 pass
+             finally:
+                 enable_progress_bars()
+
+         sleep = 3.0
+         while retries > 0:
+             retries -= 1
+
+             if hf_source and not local_files_only:
+                 # we have already tried loading with `local_files_only=True` via hf and we failed
+                 try:
+                     return Path(
+                         cls.download_files_from_huggingface(
+                             hf_source,
+                             cache_dir=cache_dir,
+                             extra_patterns=extra_patterns,
+                             **kwargs,
+                         )
+                     )
+                 except (EnvironmentError, RepositoryNotFoundError, ValueError) as e:
+                     if not local_files_only:
+                         logger.error(
+                             f"Could not download model from HuggingFace: {e} "
+                             "Falling back to other sources."
+                         )
+                 finally:
+                     enable_progress_bars()
+             if url_source or local_files_only:
+                 try:
+                     return cls.retrieve_model_gcs(
+                         model.model,
+                         str(url_source),
+                         str(cache_dir),
+                         deprecated_tar_struct=model.sources.deprecated_tar_struct,
+                         local_files_only=local_files_only,
+                     )
+                 except Exception:
+                     if not local_files_only:
+                         logger.error(f"Could not download model from url: {url_source}")
+
+             if local_files_only:
+                 logger.error("Could not find model in cache_dir")
+                 break
+             else:
+                 logger.error(
+                     f"Could not download model from either source, sleeping for {sleep} seconds, {retries} retries left."
+                 )
+                 time.sleep(sleep)
+                 sleep *= 3
+
+         raise ValueError(f"Could not load model {model.model} from any source.")
fastembed/common/onnx_model.py
@@ -0,0 +1,188 @@
+ import warnings
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any, Generic, Iterable, Sequence, Type, TypeVar
+
+ import numpy as np
+ import onnxruntime as ort
+
+ from numpy.typing import NDArray
+ from tokenizers import Tokenizer
+
+ from fastembed.common.types import OnnxProvider, NumpyArray, Device
+ from fastembed.parallel_processor import Worker
+
+ # Holds type of the embedding result
+ T = TypeVar("T")
+
+
+ @dataclass
+ class OnnxOutputContext:
+     model_output: NumpyArray
+     attention_mask: NDArray[np.int64] | None = None
+     input_ids: NDArray[np.int64] | None = None
+     metadata: dict[str, Any] | None = None
+
+
+ class OnnxModel(Generic[T]):
+     EXPOSED_SESSION_OPTIONS = ("enable_cpu_mem_arena",)
+
+     @classmethod
+     def _get_worker_class(cls) -> Type["EmbeddingWorker[T]"]:
+         raise NotImplementedError("Subclasses must implement this method")
+
+     def _post_process_onnx_output(self, output: OnnxOutputContext, **kwargs: Any) -> Iterable[T]:
+         """Post-process the ONNX model output to convert it into a usable format.
+
+         Args:
+             output (OnnxOutputContext): The raw output from the ONNX model.
+             **kwargs: Additional keyword arguments that may be needed by specific implementations.
+
+         Returns:
+             Iterable[T]: Post-processed output as an iterable of type T.
+         """
+         raise NotImplementedError("Subclasses must implement this method")
+
+     def __init__(self) -> None:
+         self.model: ort.InferenceSession | None = None
+         self.tokenizer: Tokenizer | None = None
+
+     def _preprocess_onnx_input(
+         self, onnx_input: dict[str, NumpyArray], **kwargs: Any
+     ) -> dict[str, NumpyArray]:
+         """
+         Preprocess the onnx input.
+         """
+         return onnx_input
+
+     def _load_onnx_model(
+         self,
+         model_dir: Path,
+         model_file: str,
+         threads: int | None,
+         providers: Sequence[OnnxProvider] | None = None,
+         cuda: bool | Device = Device.AUTO,
+         device_id: int | None = None,
+         extra_session_options: dict[str, Any] | None = None,
+     ) -> None:
+         model_path = model_dir / model_file
+         # List of Execution Providers: https://onnxruntime.ai/docs/execution-providers
+         available_providers = ort.get_available_providers()
+         cuda_available = "CUDAExecutionProvider" in available_providers
+         explicit_cuda = cuda is True or cuda == Device.CUDA
+
+         if explicit_cuda and providers is not None:
+             warnings.warn(
+                 f"`cuda` and `providers` are mutually exclusive parameters, "
+                 f"cuda: {cuda}, providers: {providers}. If you'd like to use providers, cuda should be one of "
+                 f"[False, Device.CPU, Device.AUTO].",
+                 category=UserWarning,
+                 stacklevel=6,
+             )
+
+         if providers is not None:
+             onnx_providers = list(providers)
+         elif explicit_cuda or (cuda == Device.AUTO and cuda_available):
+             if device_id is None:
+                 onnx_providers = ["CUDAExecutionProvider"]
+             else:
+                 onnx_providers = [("CUDAExecutionProvider", {"device_id": device_id})]
+         else:
+             onnx_providers = ["CPUExecutionProvider"]
+
+         requested_provider_names: list[str] = []
+         for provider in onnx_providers:
+             # check providers available
+             provider_name = provider if isinstance(provider, str) else provider[0]
+             requested_provider_names.append(provider_name)
+             if provider_name not in available_providers:
+                 raise ValueError(
+                     f"Provider {provider_name} is not available. Available providers: {available_providers}"
+                 )
+
+         so = ort.SessionOptions()
+         so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
+
+         if threads is not None:
+             so.intra_op_num_threads = threads
+             so.inter_op_num_threads = threads
+
+         if extra_session_options is not None:
+             self.add_extra_session_options(so, extra_session_options)
+
+         self.model = ort.InferenceSession(
+             str(model_path), providers=onnx_providers, sess_options=so
+         )
+         if "CUDAExecutionProvider" in requested_provider_names:
+             assert self.model is not None
+             current_providers = self.model.get_providers()
+             if "CUDAExecutionProvider" not in current_providers:
+                 warnings.warn(
+                     f"Attempt to set CUDAExecutionProvider failed. Current providers: {current_providers}. "
+                     "If you are using CUDA 12.x, install onnxruntime-gpu via "
+                     "`pip install onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/`",
+                     RuntimeWarning,
+                 )
+
+     @classmethod
+     def _select_exposed_session_options(cls, model_kwargs: dict[str, Any]) -> dict[str, Any]:
+         """A convenience method to select the exposed session options in models
+
+         Args:
+             model_kwargs (dict[str, Any]): The model kwargs.
+
+         Returns:
+             dict[str, Any]: a dict with filtered exposed session options.
+         """
+         return {k: v for k, v in model_kwargs.items() if k in cls.EXPOSED_SESSION_OPTIONS}
+
+     @classmethod
+     def add_extra_session_options(
+         cls, session_options: ort.SessionOptions, extra_options: dict[str, Any]
+     ) -> None:
+         """Add extra session options to the existing options object in-place
+
+         Args:
+             session_options (ort.SessionOptions): The existing session options object.
+             extra_options (dict[str, Any]): The extra session options available in cls.EXPOSED_SESSION_OPTIONS.
+
+         Returns:
+             None
+         """
+         for option in extra_options:
+             assert (
+                 option in cls.EXPOSED_SESSION_OPTIONS
+             ), f"{option} is unknown or not exposed (exposed options: {cls.EXPOSED_SESSION_OPTIONS})"
+         if "enable_cpu_mem_arena" in extra_options:
+             session_options.enable_cpu_mem_arena = extra_options["enable_cpu_mem_arena"]
+
+     def load_onnx_model(self) -> None:
+         raise NotImplementedError("Subclasses must implement this method")
+
+     def onnx_embed(self, *args: Any, **kwargs: Any) -> OnnxOutputContext:
+         raise NotImplementedError("Subclasses must implement this method")
+
+
+ class EmbeddingWorker(Worker, Generic[T]):
+     def init_embedding(
+         self,
+         model_name: str,
+         cache_dir: str,
+         **kwargs: Any,
+     ) -> OnnxModel[T]:
+         raise NotImplementedError()
+
+     def __init__(
+         self,
+         model_name: str,
+         cache_dir: str,
+         **kwargs: Any,
+     ):
+         self.model = self.init_embedding(model_name, cache_dir, **kwargs)
+
+     @classmethod
+     def start(cls, model_name: str, cache_dir: str, **kwargs: Any) -> "EmbeddingWorker[T]":
+         return cls(model_name=model_name, cache_dir=cache_dir, **kwargs)
+
+     def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
+         raise NotImplementedError("Subclasses must implement this method")