huggingface-api-haystack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- haystack_integrations/components/common/huggingface_api/__init__.py +3 -0
- haystack_integrations/components/common/huggingface_api/utils.py +112 -0
- haystack_integrations/components/common/py.typed +0 -0
- haystack_integrations/components/embedders/huggingface_api/__init__.py +7 -0
- haystack_integrations/components/embedders/huggingface_api/document_embedder.py +382 -0
- haystack_integrations/components/embedders/huggingface_api/text_embedder.py +262 -0
- haystack_integrations/components/embedders/py.typed +0 -0
- haystack_integrations/components/generators/huggingface_api/__init__.py +6 -0
- haystack_integrations/components/generators/huggingface_api/chat/__init__.py +3 -0
- haystack_integrations/components/generators/huggingface_api/chat/chat_generator.py +738 -0
- haystack_integrations/components/generators/py.typed +0 -0
- huggingface_api_haystack-0.1.0.dist-info/METADATA +40 -0
- huggingface_api_haystack-0.1.0.dist-info/RECORD +15 -0
- huggingface_api_haystack-0.1.0.dist-info/WHEEL +4 -0
- huggingface_api_haystack-0.1.0.dist-info/licenses/LICENSE.txt +201 -0
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
2
|
+
#
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
from enum import Enum
|
|
6
|
+
|
|
7
|
+
from haystack.utils import Secret
|
|
8
|
+
from huggingface_hub import HfApi
|
|
9
|
+
from huggingface_hub.errors import RepositoryNotFoundError
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class HFGenerationAPIType(Enum):
    """
    Selects which Hugging Face API a Generator component talks to.

    Each member's value is the plain-string identifier that `from_str` accepts and `str()` returns.
    """

    # Hugging Face [Text Generation Inference (TGI)](https://github.com/huggingface/text-generation-inference).
    TEXT_GENERATION_INFERENCE = "text_generation_inference"

    # Hugging Face [Inference Endpoints](https://huggingface.co/inference-endpoints).
    INFERENCE_ENDPOINTS = "inference_endpoints"

    # Hugging Face [Serverless Inference API](https://huggingface.co/inference-api).
    SERVERLESS_INFERENCE_API = "serverless_inference_api"

    def __str__(self) -> str:
        # Render as the bare identifier (e.g. "inference_endpoints") instead of "HFGenerationAPIType.X".
        return str(self.value)

    @staticmethod
    def from_str(string: str) -> "HFGenerationAPIType":
        """
        Look up the HFGenerationAPIType member whose value equals `string`.

        :param string: The identifier to look up, for example "serverless_inference_api".
        :return: The matching HFGenerationAPIType member.
        :raises ValueError: If `string` is not the value of any member.
        """
        members_by_value = {member.value: member for member in HFGenerationAPIType}
        if string not in members_by_value:
            supported = list(members_by_value)
            msg = f"Unknown Hugging Face API type '{string}'. Supported types are: {supported}"
            raise ValueError(msg)
        return members_by_value[string]
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class HFEmbeddingAPIType(Enum):
    """
    Selects which Hugging Face API an Embedder component talks to.

    Each member's value is the plain-string identifier that `from_str` accepts and `str()` returns.
    """

    # Hugging Face [Text Embeddings Inference (TEI)](https://github.com/huggingface/text-embeddings-inference).
    TEXT_EMBEDDINGS_INFERENCE = "text_embeddings_inference"

    # Hugging Face [Inference Endpoints](https://huggingface.co/inference-endpoints).
    INFERENCE_ENDPOINTS = "inference_endpoints"

    # Hugging Face [Serverless Inference API](https://huggingface.co/inference-api).
    SERVERLESS_INFERENCE_API = "serverless_inference_api"

    def __str__(self) -> str:
        # Render as the bare identifier (e.g. "inference_endpoints") instead of "HFEmbeddingAPIType.X".
        return str(self.value)

    @staticmethod
    def from_str(string: str) -> "HFEmbeddingAPIType":
        """
        Look up the HFEmbeddingAPIType member whose value equals `string`.

        :param string: The identifier to look up, for example "text_embeddings_inference".
        :return: The matching HFEmbeddingAPIType member.
        :raises ValueError: If `string` is not the value of any member.
        """
        members_by_value = {member.value: member for member in HFEmbeddingAPIType}
        if string not in members_by_value:
            supported = list(members_by_value)
            msg = f"Unknown Hugging Face API type '{string}'. Supported types are: {supported}"
            raise ValueError(msg)
        return members_by_value[string]
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class HFModelType(Enum):
    """
    Kind of model that `_check_valid_model` should verify a Hub model against.
    """

    # Accepted when the model's Hub pipeline tag is "sentence-similarity" or "feature-extraction".
    EMBEDDING = 1
    # Accepted when the model's Hub pipeline tag is "text-generation", "text2text-generation"
    # or "image-text-to-text".
    GENERATION = 2
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _check_valid_model(model_id: str, model_type: HFModelType, token: Secret | None) -> None:
    """
    Verify that `model_id` exists on the Hugging Face Hub and suits the requested kind of model.

    Suitability is decided by the model's Hub `pipeline_tag`:
    - `HFModelType.EMBEDDING` accepts "sentence-similarity" and "feature-extraction".
    - `HFModelType.GENERATION` accepts "text-generation", "text2text-generation" and "image-text-to-text".
    Any other `model_type` is rejected.

    :param model_id: A string representing the Hugging Face model ID.
    :param model_type: The expected kind of model, `HFModelType.EMBEDDING` or `HFModelType.GENERATION`.
    :param token: The optional authentication token used for the Hub lookup.
    :raises ValueError: If the model is not found on the Hub, if its pipeline tag does not match
        `model_type`, or if `model_type` is not a known model type.
    """
    hub = HfApi()
    try:
        info = hub.model_info(model_id, token=token.resolve_value() if token else None)
    except RepositoryNotFoundError as e:
        msg = f"Model {model_id} not found on HuggingFace Hub. Please provide a valid HuggingFace model_id."
        raise ValueError(msg) from e

    if model_type == HFModelType.EMBEDDING:
        accepted_tags: tuple[str, ...] = ("sentence-similarity", "feature-extraction")
        mismatch_msg = f"Model {model_id} is not a embedding model. Please provide a embedding model."
    elif model_type == HFModelType.GENERATION:
        accepted_tags = ("text-generation", "text2text-generation", "image-text-to-text")
        mismatch_msg = f"Model {model_id} is not a text generation model. Please provide a text generation model."
    else:
        # No pipeline tag can satisfy an unknown model type, so reject it outright.
        msg = f"Unknown model type for {model_id}"
        raise ValueError(msg)

    if info.pipeline_tag not in accepted_tags:
        raise ValueError(mismatch_msg)
|
|
File without changes
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
2
|
+
#
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
from .document_embedder import HuggingFaceAPIDocumentEmbedder
|
|
5
|
+
from .text_embedder import HuggingFaceAPITextEmbedder
|
|
6
|
+
|
|
7
|
+
# Public names re-exported by the `embedders.huggingface_api` package.
__all__ = ["HuggingFaceAPIDocumentEmbedder", "HuggingFaceAPITextEmbedder"]
|
|
@@ -0,0 +1,382 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
2
|
+
#
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
from asyncio import Semaphore, gather
|
|
6
|
+
from dataclasses import replace
|
|
7
|
+
from itertools import chain
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from haystack import component, default_from_dict, default_to_dict, logging
|
|
11
|
+
from haystack.dataclasses import Document
|
|
12
|
+
from haystack.utils import Secret
|
|
13
|
+
from haystack.utils.url_validation import is_valid_http_url
|
|
14
|
+
from huggingface_hub import AsyncInferenceClient, InferenceClient
|
|
15
|
+
from tqdm import tqdm
|
|
16
|
+
|
|
17
|
+
from haystack_integrations.components.common.huggingface_api.utils import (
|
|
18
|
+
HFEmbeddingAPIType,
|
|
19
|
+
HFModelType,
|
|
20
|
+
_check_valid_model,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
# Module-level logger obtained through haystack's logging wrapper.
logger = logging.getLogger(__name__)

# batched embeddings returned by the API are expected to have shape (batch_size, embedding_dim)
_EXPECTED_EMBEDDING_NDIM = 2
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@component
class HuggingFaceAPIDocumentEmbedder:
    """
    Embeds documents using Hugging Face APIs.

    Use it with the following Hugging Face APIs:
    - [Free Serverless Inference API](https://huggingface.co/inference-api)
    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
    - [Self-hosted Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference)

    ### Usage examples

    #### With free serverless inference API

    ```python
    from haystack_integrations.components.embedders.huggingface_api import HuggingFaceAPIDocumentEmbedder
    from haystack.utils import Secret
    from haystack.dataclasses import Document

    doc = Document(content="I love pizza!")

    doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="serverless_inference_api",
                                                  api_params={"model": "BAAI/bge-small-en-v1.5"},
                                                  token=Secret.from_token("<your-api-key>"))

    result = doc_embedder.run([doc])
    print(result["documents"][0].embedding)

    # [0.017020374536514282, -0.023255806416273117, ...]
    ```

    #### With paid inference endpoints

    ```python
    from haystack_integrations.components.embedders.huggingface_api import HuggingFaceAPIDocumentEmbedder
    from haystack.utils import Secret
    from haystack.dataclasses import Document

    doc = Document(content="I love pizza!")

    doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="inference_endpoints",
                                                  api_params={"url": "<your-inference-endpoint-url>"},
                                                  token=Secret.from_token("<your-api-key>"))

    result = doc_embedder.run([doc])
    print(result["documents"][0].embedding)

    # [0.017020374536514282, -0.023255806416273117, ...]
    ```

    #### With self-hosted text embeddings inference

    ```python
    from haystack_integrations.components.embedders.huggingface_api import HuggingFaceAPIDocumentEmbedder
    from haystack.dataclasses import Document

    doc = Document(content="I love pizza!")

    doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="text_embeddings_inference",
                                                  api_params={"url": "http://localhost:8080"})

    result = doc_embedder.run([doc])
    print(result["documents"][0].embedding)

    # [0.017020374536514282, -0.023255806416273117, ...]
    ```
    """

    def __init__(
        self,
        api_type: HFEmbeddingAPIType | str,
        api_params: dict[str, str],
        token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        prefix: str = "",
        suffix: str = "",
        truncate: bool | None = True,
        normalize: bool | None = False,
        batch_size: int = 32,
        progress_bar: bool = True,
        meta_fields_to_embed: list[str] | None = None,
        embedding_separator: str = "\n",
        concurrency_limit: int = 4,
    ) -> None:
        """
        Creates a HuggingFaceAPIDocumentEmbedder component.

        :param api_type:
            The type of Hugging Face API to use.
        :param api_params:
            A dictionary with the following keys:
            - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
            `TEXT_EMBEDDINGS_INFERENCE`.
        :param token: The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param prefix:
            A string to add at the beginning of each text.
        :param suffix:
            A string to add at the end of each text.
        :param truncate:
            Truncates the input text to the maximum length supported by the model.
            Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
            if the backend uses Text Embeddings Inference.
            If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored (a warning is logged).
        :param normalize:
            Normalizes the embeddings to unit length.
            Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
            if the backend uses Text Embeddings Inference.
            If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored (a warning is logged).
        :param batch_size:
            Number of documents to process at once. Must be at least 1.
        :param progress_bar:
            If `True`, shows a progress bar when running.
        :param meta_fields_to_embed:
            List of metadata fields to embed along with the document text.
        :param embedding_separator:
            Separator used to concatenate the metadata fields to the document text.
        :param concurrency_limit:
            The maximum number of requests that should be allowed to run concurrently.
            This parameter is only used in the `run_async` method.
        :raises ValueError:
            If `api_type` is unknown, a required `api_params` key is missing, the URL is invalid,
            the serverless model is not a valid embedding model, or `batch_size` is smaller than 1.
        """
        if isinstance(api_type, str):
            api_type = HFEmbeddingAPIType.from_str(api_type)

        api_params = api_params or {}

        # Fail fast: a non-positive batch size would otherwise only surface at run time
        # (as a `range()` error in `run` or a ZeroDivisionError in `run_async`).
        if batch_size < 1:
            msg = f"`batch_size` must be a positive integer, got {batch_size}."
            raise ValueError(msg)

        if api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
            model = api_params.get("model")
            if model is None:
                msg = "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
                raise ValueError(msg)
            _check_valid_model(model, HFModelType.EMBEDDING, token)
            model_or_url = model
        elif api_type in [HFEmbeddingAPIType.INFERENCE_ENDPOINTS, HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE]:
            url = api_params.get("url")
            if url is None:
                msg = (
                    "To use Text Embeddings Inference or Inference Endpoints, you need to specify the `url` "
                    "parameter in `api_params`."
                )
                raise ValueError(msg)
            if not is_valid_http_url(url):
                msg = f"Invalid URL: {url}"
                raise ValueError(msg)
            model_or_url = url
        else:
            msg = f"Unknown api_type {api_type}"
            raise ValueError(msg)

        client_args: dict[str, Any] = {"model": model_or_url, "token": token.resolve_value() if token else None}

        self.api_type = api_type
        self.api_params = api_params
        self.token = token
        self.prefix = prefix
        self.suffix = suffix
        self.truncate = truncate
        self.normalize = normalize
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.meta_fields_to_embed = meta_fields_to_embed or []
        self.embedding_separator = embedding_separator
        self.concurrency_limit = concurrency_limit
        self._client = InferenceClient(**client_args)
        self._async_client = AsyncInferenceClient(**client_args)

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            api_type=str(self.api_type),
            api_params=self.api_params,
            prefix=self.prefix,
            suffix=self.suffix,
            token=self.token,
            truncate=self.truncate,
            normalize=self.normalize,
            batch_size=self.batch_size,
            progress_bar=self.progress_bar,
            meta_fields_to_embed=self.meta_fields_to_embed,
            embedding_separator=self.embedding_separator,
            concurrency_limit=self.concurrency_limit,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "HuggingFaceAPIDocumentEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        return default_from_dict(cls, data)

    @staticmethod
    def _validate_documents(documents: Any) -> None:
        """
        Raise TypeError unless `documents` is a list whose items are all Documents.

        Every item is checked, not just the first, so a stray non-Document later in the list
        produces this clear error instead of an AttributeError during text preparation.
        """
        if not isinstance(documents, list) or not all(isinstance(doc, Document) for doc in documents):
            msg = (
                "HuggingFaceAPIDocumentEmbedder expects a list of Documents as input."
                " In case you want to embed a string, please use the HuggingFaceAPITextEmbedder."
            )
            raise TypeError(msg)

    def _prepare_texts_to_embed(self, documents: list[Document]) -> list[str]:
        """
        Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.

        Metadata values are stringified and come first, followed by the document content (empty if None),
        all joined with `embedding_separator`, then wrapped with `prefix` and `suffix`.
        Missing or None metadata fields are skipped.
        """
        texts_to_embed = []
        for doc in documents:
            meta_values_to_embed = [
                str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
            ]

            text_to_embed = (
                self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix
            )

            texts_to_embed.append(text_to_embed)
        return texts_to_embed

    @staticmethod
    def _adjust_api_parameters(
        truncate: bool | None, normalize: bool | None, api_type: HFEmbeddingAPIType
    ) -> tuple[bool | None, bool | None]:
        """
        Adjust the truncate and normalize parameters based on the API type.

        For the Serverless Inference API both parameters are unsupported: any non-None value is
        replaced by None and a warning is logged.
        """
        if api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
            if truncate is not None:
                msg = "`truncate` parameter is not supported for Serverless Inference API. It will be ignored."
                logger.warning(msg)
                truncate = None
            if normalize is not None:
                msg = "`normalize` parameter is not supported for Serverless Inference API. It will be ignored."
                logger.warning(msg)
                normalize = None
        return truncate, normalize

    @staticmethod
    def _check_embeddings_shape(np_embeddings: Any, expected_rows: int) -> None:
        """
        Raise ValueError unless `np_embeddings` is 2-D with exactly `expected_rows` rows.

        :param np_embeddings: The array returned by the client's `feature_extraction` call.
        :param expected_rows: Number of texts in the batch that was sent. The final batch may be
            smaller than `batch_size`, so the actual batch length is reported.
        """
        if np_embeddings.ndim != _EXPECTED_EMBEDDING_NDIM or np_embeddings.shape[0] != expected_rows:
            msg = f"Expected embedding shape ({expected_rows}, embedding_dim), got {np_embeddings.shape}"
            raise ValueError(msg)

    def _embed_batch(self, texts_to_embed: list[str], batch_size: int) -> list[list[float]]:
        """
        Embed a list of texts in batches using the synchronous client.

        :param texts_to_embed: Texts to embed, in order.
        :param batch_size: Maximum number of texts sent per request.
        :returns: One embedding per input text, in the same order.
        :raises ValueError: If the API returns embeddings with an unexpected shape.
        """
        truncate, normalize = self._adjust_api_parameters(self.truncate, self.normalize, self.api_type)

        all_embeddings: list[list[float]] = []
        for i in tqdm(
            range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
        ):
            batch = texts_to_embed[i : i + batch_size]

            np_embeddings = self._client.feature_extraction(text=batch, truncate=truncate, normalize=normalize)
            self._check_embeddings_shape(np_embeddings, expected_rows=len(batch))

            all_embeddings.extend(np_embeddings.tolist())

        return all_embeddings

    async def _embed_batch_async(self, texts_to_embed: list[str], batch_size: int) -> list[list[float]]:
        """
        Embed a list of texts in batches asynchronously.

        At most `concurrency_limit` requests (at least 1) are in flight at a time.

        :param texts_to_embed: Texts to embed, in order.
        :param batch_size: Maximum number of texts sent per request.
        :returns: One embedding per input text, in the same order.
        :raises ValueError: If the API returns embeddings with an unexpected shape.
        """
        truncate, normalize = self._adjust_api_parameters(self.truncate, self.normalize, self.api_type)
        # Clamp to at least 1 so a non-positive limit still lets requests through.
        sem = Semaphore(max(1, self.concurrency_limit))
        batches = [texts_to_embed[i : i + batch_size] for i in range(0, len(texts_to_embed), batch_size)]
        pbar = tqdm(total=len(batches), disable=not self.progress_bar, desc="Calculating embeddings")

        async def _runner(batch: list[str]) -> list[list[float]]:
            async with sem:
                np_embeddings = await self._async_client.feature_extraction(
                    text=batch, truncate=truncate, normalize=normalize
                )

                self._check_embeddings_shape(np_embeddings, expected_rows=len(batch))

                pbar.update(1)
                return np_embeddings.tolist()

        try:
            # `gather` returns results in submission order, keeping embeddings aligned with the inputs.
            per_batch_embeddings = await gather(*[_runner(batch) for batch in batches])
        finally:
            pbar.close()

        return list(chain.from_iterable(per_batch_embeddings))

    @component.output_types(documents=list[Document])
    def run(self, documents: list[Document]) -> dict[str, list[Document]]:
        """
        Embeds a list of documents.

        :param documents:
            Documents to embed.

        :returns:
            A dictionary with the following keys:
            - `documents`: A list of documents with embeddings. The input documents are not modified;
              copies with the `embedding` field set are returned.
        :raises TypeError:
            If `documents` is not a list of Documents.
        """
        self._validate_documents(documents)

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)

        embeddings = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)

        new_documents = [replace(doc, embedding=emb) for doc, emb in zip(documents, embeddings, strict=True)]

        return {"documents": new_documents}

    @component.output_types(documents=list[Document])
    async def run_async(self, documents: list[Document]) -> dict[str, list[Document]]:
        """
        Embeds a list of documents asynchronously.

        :param documents:
            Documents to embed.

        :returns:
            A dictionary with the following keys:
            - `documents`: A list of documents with embeddings. The input documents are not modified;
              copies with the `embedding` field set are returned.
        :raises TypeError:
            If `documents` is not a list of Documents.
        """
        self._validate_documents(documents)

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)

        embeddings = await self._embed_batch_async(texts_to_embed=texts_to_embed, batch_size=self.batch_size)

        new_documents = [replace(doc, embedding=emb) for doc, emb in zip(documents, embeddings, strict=True)]

        return {"documents": new_documents}
|