huggingface-api-haystack 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,3 @@
1
+ # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
@@ -0,0 +1,112 @@
1
+ # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ from enum import Enum
6
+
7
+ from haystack.utils import Secret
8
+ from huggingface_hub import HfApi
9
+ from huggingface_hub.errors import RepositoryNotFoundError
10
+
11
+
12
class HFGenerationAPIType(Enum):
    """
    Enumerates the Hugging Face API flavours that Generators can talk to.

    Each member's value is the plain string accepted by `from_str` and produced by `str()`.
    """

    # Hugging Face [Text Generation Inference (TGI)](https://github.com/huggingface/text-generation-inference).
    TEXT_GENERATION_INFERENCE = "text_generation_inference"

    # Hugging Face [Inference Endpoints](https://huggingface.co/inference-endpoints).
    INFERENCE_ENDPOINTS = "inference_endpoints"

    # Hugging Face [Serverless Inference API](https://huggingface.co/inference-api).
    SERVERLESS_INFERENCE_API = "serverless_inference_api"

    def __str__(self) -> str:
        # The string form is the raw member value, so it round-trips through `from_str`.
        return self.value

    @staticmethod
    def from_str(string: str) -> "HFGenerationAPIType":
        """
        Look up the HFGenerationAPIType member whose value equals the given string.

        :param string: The string to convert.
        :return: The matching HFGenerationAPIType member.
        :raises ValueError: If no member has `string` as its value.
        """
        members_by_value = {member.value: member for member in HFGenerationAPIType}
        if string not in members_by_value:
            supported = list(members_by_value)
            msg = f"Unknown Hugging Face API type '{string}'. Supported types are: {supported}"
            raise ValueError(msg)
        return members_by_value[string]
43
+
44
+
45
class HFEmbeddingAPIType(Enum):
    """
    Enumerates the Hugging Face API flavours that Embedders can talk to.

    Each member's value is the plain string accepted by `from_str` and produced by `str()`.
    """

    # Hugging Face [Text Embeddings Inference (TEI)](https://github.com/huggingface/text-embeddings-inference).
    TEXT_EMBEDDINGS_INFERENCE = "text_embeddings_inference"

    # Hugging Face [Inference Endpoints](https://huggingface.co/inference-endpoints).
    INFERENCE_ENDPOINTS = "inference_endpoints"

    # Hugging Face [Serverless Inference API](https://huggingface.co/inference-api).
    SERVERLESS_INFERENCE_API = "serverless_inference_api"

    def __str__(self) -> str:
        # The string form is the raw member value, so it round-trips through `from_str`.
        return self.value

    @staticmethod
    def from_str(string: str) -> "HFEmbeddingAPIType":
        """
        Look up the HFEmbeddingAPIType member whose value equals the given string.

        :param string: The string to convert.
        :return: The matching HFEmbeddingAPIType member.
        :raises ValueError: If no member has `string` as its value.
        """
        members_by_value = {member.value: member for member in HFEmbeddingAPIType}
        if string not in members_by_value:
            supported = list(members_by_value)
            msg = f"Unknown Hugging Face API type '{string}'. Supported types are: {supported}"
            raise ValueError(msg)
        return members_by_value[string]
76
+
77
+
78
class HFModelType(Enum):
    """
    Kind of Hugging Face model that a validity check should expect.
    """

    # Model expected to be an embedding model.
    EMBEDDING = 1
    # Model expected to be a generation model.
    GENERATION = 2
81
+
82
+
83
def _check_valid_model(model_id: str, model_type: HFModelType, token: Secret | None) -> None:
    """
    Verify that `model_id` exists on the Hugging Face Hub and that its pipeline tag fits `model_type`.

    :param model_id: A string representing the Hugging Face model ID.
    :param model_type: The expected model type, `HFModelType.EMBEDDING` or `HFModelType.GENERATION`.
    :param token: The optional authentication token; it is resolved and forwarded to the Hub lookup.
    :raises ValueError: If the Hub reports the repository as not found, if the model's pipeline tag is
        not accepted for `model_type`, or if `model_type` is neither of the known types.
    """
    hub = HfApi()
    auth_token = token.resolve_value() if token else None
    try:
        info = hub.model_info(model_id, token=auth_token)
    except RepositoryNotFoundError as e:
        msg = f"Model {model_id} not found on HuggingFace Hub. Please provide a valid HuggingFace model_id."
        raise ValueError(msg) from e

    # Pick the accepted pipeline tags and the failure message for the requested model type.
    if model_type == HFModelType.EMBEDDING:
        accepted_tags = ("sentence-similarity", "feature-extraction")
        failure = f"Model {model_id} is not a embedding model. Please provide a embedding model."
    elif model_type == HFModelType.GENERATION:
        accepted_tags = ("text-generation", "text2text-generation", "image-text-to-text")
        failure = f"Model {model_id} is not a text generation model. Please provide a text generation model."
    else:
        # An unrecognised model type can never be satisfied, whatever the pipeline tag is.
        raise ValueError(f"Unknown model type for {model_id}")

    if info.pipeline_tag not in accepted_tags:
        raise ValueError(failure)
File without changes
@@ -0,0 +1,7 @@
1
+ # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ from .document_embedder import HuggingFaceAPIDocumentEmbedder
5
+ from .text_embedder import HuggingFaceAPITextEmbedder
6
+
7
+ __all__ = ["HuggingFaceAPIDocumentEmbedder", "HuggingFaceAPITextEmbedder"]
@@ -0,0 +1,382 @@
1
+ # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ from asyncio import Semaphore, gather
6
+ from dataclasses import replace
7
+ from itertools import chain
8
+ from typing import Any
9
+
10
+ from haystack import component, default_from_dict, default_to_dict, logging
11
+ from haystack.dataclasses import Document
12
+ from haystack.utils import Secret
13
+ from haystack.utils.url_validation import is_valid_http_url
14
+ from huggingface_hub import AsyncInferenceClient, InferenceClient
15
+ from tqdm import tqdm
16
+
17
+ from haystack_integrations.components.common.huggingface_api.utils import (
18
+ HFEmbeddingAPIType,
19
+ HFModelType,
20
+ _check_valid_model,
21
+ )
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+ # batched embeddings returned by the API are expected to have shape (batch_size, embedding_dim)
26
+ _EXPECTED_EMBEDDING_NDIM = 2
27
+
28
+
29
@component
class HuggingFaceAPIDocumentEmbedder:
    """
    Embeds documents using Hugging Face APIs.

    Use it with the following Hugging Face APIs:
    - [Free Serverless Inference API](https://huggingface.co/inference-api)
    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
    - [Self-hosted Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference)


    ### Usage examples

    #### With free serverless inference API

    ```python
    from haystack_integrations.components.embedders.huggingface_api import HuggingFaceAPIDocumentEmbedder
    from haystack.utils import Secret
    from haystack.dataclasses import Document

    doc = Document(content="I love pizza!")

    doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="serverless_inference_api",
                                                  api_params={"model": "BAAI/bge-small-en-v1.5"},
                                                  token=Secret.from_token("<your-api-key>"))

    result = doc_embedder.run([doc])
    print(result["documents"][0].embedding)

    # [0.017020374536514282, -0.023255806416273117, ...]
    ```

    #### With paid inference endpoints

    ```python
    from haystack_integrations.components.embedders.huggingface_api import HuggingFaceAPIDocumentEmbedder
    from haystack.utils import Secret
    from haystack.dataclasses import Document

    doc = Document(content="I love pizza!")

    doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="inference_endpoints",
                                                  api_params={"url": "<your-inference-endpoint-url>"},
                                                  token=Secret.from_token("<your-api-key>"))

    result = doc_embedder.run([doc])
    print(result["documents"][0].embedding)

    # [0.017020374536514282, -0.023255806416273117, ...]
    ```

    #### With self-hosted text embeddings inference

    ```python
    from haystack_integrations.components.embedders.huggingface_api import HuggingFaceAPIDocumentEmbedder
    from haystack.dataclasses import Document

    doc = Document(content="I love pizza!")

    doc_embedder = HuggingFaceAPIDocumentEmbedder(api_type="text_embeddings_inference",
                                                  api_params={"url": "http://localhost:8080"})

    result = doc_embedder.run([doc])
    print(result["documents"][0].embedding)

    # [0.017020374536514282, -0.023255806416273117, ...]
    ```
    """

    def __init__(
        self,
        api_type: HFEmbeddingAPIType | str,
        api_params: dict[str, str],
        token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        prefix: str = "",
        suffix: str = "",
        truncate: bool | None = True,
        normalize: bool | None = False,
        batch_size: int = 32,
        progress_bar: bool = True,
        meta_fields_to_embed: list[str] | None = None,
        embedding_separator: str = "\n",
        concurrency_limit: int = 4,
    ) -> None:
        """
        Creates a HuggingFaceAPIDocumentEmbedder component.

        :param api_type:
            The type of Hugging Face API to use.
        :param api_params:
            A dictionary with the following keys:
            - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
            `TEXT_EMBEDDINGS_INFERENCE`.
        :param token: The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param prefix:
            A string to add at the beginning of each text.
        :param suffix:
            A string to add at the end of each text.
        :param truncate:
            Truncates the input text to the maximum length supported by the model.
            Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
            if the backend uses Text Embeddings Inference.
            If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
        :param normalize:
            Normalizes the embeddings to unit length.
            Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
            if the backend uses Text Embeddings Inference.
            If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
        :param batch_size:
            Number of documents to process at once.
        :param progress_bar:
            If `True`, shows a progress bar when running.
        :param meta_fields_to_embed:
            List of metadata fields to embed along with the document text.
        :param embedding_separator:
            Separator used to concatenate the metadata fields to the document text.
        :param concurrency_limit:
            The maximum number of requests that should be allowed to run concurrently.
            This parameter is only used in the `run_async` method.
        :raises ValueError:
            If `api_type` is unknown, if the `model`/`url` entry required by `api_type` is missing from
            `api_params`, if the URL is not a valid HTTP URL, or if the model fails validation.
        """
        if isinstance(api_type, str):
            api_type = HFEmbeddingAPIType.from_str(api_type)

        api_params = api_params or {}

        if api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
            model = api_params.get("model")
            if model is None:
                msg = "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
                raise ValueError(msg)
            _check_valid_model(model, HFModelType.EMBEDDING, token)
            model_or_url = model
        elif api_type in [HFEmbeddingAPIType.INFERENCE_ENDPOINTS, HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE]:
            url = api_params.get("url")
            if url is None:
                msg = (
                    "To use Text Embeddings Inference or Inference Endpoints, you need to specify the `url` "
                    "parameter in `api_params`."
                )
                raise ValueError(msg)
            if not is_valid_http_url(url):
                msg = f"Invalid URL: {url}"
                raise ValueError(msg)
            model_or_url = url
        else:
            msg = f"Unknown api_type {api_type}"
            raise ValueError(msg)

        # Both clients receive the same target (model ID or URL) and the same resolved token.
        client_args: dict[str, Any] = {"model": model_or_url, "token": token.resolve_value() if token else None}

        self.api_type = api_type
        self.api_params = api_params
        self.token = token
        self.prefix = prefix
        self.suffix = suffix
        self.truncate = truncate
        self.normalize = normalize
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.meta_fields_to_embed = meta_fields_to_embed or []
        self.embedding_separator = embedding_separator
        self.concurrency_limit = concurrency_limit
        self._client = InferenceClient(**client_args)
        self._async_client = AsyncInferenceClient(**client_args)

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            api_type=str(self.api_type),
            api_params=self.api_params,
            prefix=self.prefix,
            suffix=self.suffix,
            token=self.token,
            truncate=self.truncate,
            normalize=self.normalize,
            batch_size=self.batch_size,
            progress_bar=self.progress_bar,
            meta_fields_to_embed=self.meta_fields_to_embed,
            embedding_separator=self.embedding_separator,
            concurrency_limit=self.concurrency_limit,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "HuggingFaceAPIDocumentEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        return default_from_dict(cls, data)

    def _prepare_texts_to_embed(self, documents: list[Document]) -> list[str]:
        """
        Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.

        Metadata fields that are missing or `None` are skipped; a missing document content is treated as
        an empty string. The configured prefix and suffix wrap each resulting text.
        """
        texts_to_embed = []
        for doc in documents:
            meta_values_to_embed = [
                str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
            ]

            text_to_embed = (
                self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix
            )

            texts_to_embed.append(text_to_embed)
        return texts_to_embed

    @staticmethod
    def _adjust_api_parameters(
        truncate: bool | None, normalize: bool | None, api_type: HFEmbeddingAPIType
    ) -> tuple[bool | None, bool | None]:
        """
        Adjust the truncate and normalize parameters based on the API type.

        For the Serverless Inference API, any non-`None` value is replaced by `None` and a warning is logged.
        For every other API type the values are returned unchanged.
        """
        if api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
            if truncate is not None:
                msg = "`truncate` parameter is not supported for Serverless Inference API. It will be ignored."
                logger.warning(msg)
                truncate = None
            if normalize is not None:
                msg = "`normalize` parameter is not supported for Serverless Inference API. It will be ignored."
                logger.warning(msg)
                normalize = None
        return truncate, normalize

    @staticmethod
    def _check_embeddings_shape(np_embeddings: Any, expected_rows: int) -> None:
        """
        Ensure the API returned a 2-dimensional result with one row per submitted text.

        :param np_embeddings: The object returned by `feature_extraction`; its `ndim` and `shape` are read.
        :param expected_rows: Number of texts sent in the request, i.e. the actual size of the batch.
        :raises ValueError: If the result is not 2-dimensional or its first dimension differs from
            `expected_rows`.
        """
        if np_embeddings.ndim != _EXPECTED_EMBEDDING_NDIM or np_embeddings.shape[0] != expected_rows:
            # Report the real batch length: the last batch can be shorter than the configured batch size.
            msg = f"Expected embedding shape ({expected_rows}, embedding_dim), got {np_embeddings.shape}"
            raise ValueError(msg)

    def _embed_batch(self, texts_to_embed: list[str], batch_size: int) -> list[list[float]]:
        """
        Embed a list of texts in batches.

        :param texts_to_embed: The texts to send to the API.
        :param batch_size: Maximum number of texts per request.
        :returns: One embedding per input text, in input order.
        :raises ValueError: If a response does not have the expected shape for its batch.
        """
        truncate, normalize = self._adjust_api_parameters(self.truncate, self.normalize, self.api_type)

        all_embeddings: list = []
        for i in tqdm(
            range(0, len(texts_to_embed), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
        ):
            batch = texts_to_embed[i : i + batch_size]

            np_embeddings = self._client.feature_extraction(text=batch, truncate=truncate, normalize=normalize)
            self._check_embeddings_shape(np_embeddings, len(batch))

            all_embeddings.extend(np_embeddings.tolist())

        return all_embeddings

    async def _embed_batch_async(self, texts_to_embed: list[str], batch_size: int) -> list[list[float]]:
        """
        Embed a list of texts in batches asynchronously.

        One request per batch is scheduled; a semaphore sized by `concurrency_limit` (at least 1) bounds
        how many requests are in flight at once.

        :param texts_to_embed: The texts to send to the API.
        :param batch_size: Maximum number of texts per request.
        :returns: One embedding per input text, in input order.
        :raises ValueError: If a response does not have the expected shape for its batch.
        """
        truncate, normalize = self._adjust_api_parameters(self.truncate, self.normalize, self.api_type)
        sem = Semaphore(max(1, self.concurrency_limit))
        batches = [texts_to_embed[i : i + batch_size] for i in range(0, len(texts_to_embed), batch_size)]
        pbar = tqdm(total=len(batches), disable=not self.progress_bar, desc="Calculating embeddings")

        async def _runner(batch: list[str]) -> list[list[float]]:
            async with sem:
                np_embeddings = await self._async_client.feature_extraction(
                    text=batch, truncate=truncate, normalize=normalize
                )

            self._check_embeddings_shape(np_embeddings, len(batch))

            pbar.update(1)
            return np_embeddings.tolist()

        try:
            # gather preserves the order of its arguments, so results line up with the input batches.
            per_batch_embeddings = await gather(*(_runner(batch) for batch in batches))
        finally:
            pbar.close()

        return list(chain.from_iterable(per_batch_embeddings))

    @component.output_types(documents=list[Document])
    def run(self, documents: list[Document]) -> dict[str, list[Document]]:
        """
        Embeds a list of documents.

        :param documents:
            Documents to embed.

        :returns:
            A dictionary with the following keys:
            - `documents`: A list of documents with embeddings.
        :raises TypeError:
            If `documents` is not a list, or its first element is not a `Document`.
        """
        if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
            msg = (
                "HuggingFaceAPIDocumentEmbedder expects a list of Documents as input."
                " In case you want to embed a string, please use the HuggingFaceAPITextEmbedder."
            )
            raise TypeError(msg)

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)

        embeddings = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)

        # Return copies carrying the embedding instead of mutating the caller's documents.
        new_documents = [replace(doc, embedding=emb) for doc, emb in zip(documents, embeddings, strict=True)]

        return {"documents": new_documents}

    @component.output_types(documents=list[Document])
    async def run_async(self, documents: list[Document]) -> dict[str, list[Document]]:
        """
        Embeds a list of documents asynchronously.

        :param documents:
            Documents to embed.

        :returns:
            A dictionary with the following keys:
            - `documents`: A list of documents with embeddings.
        :raises TypeError:
            If `documents` is not a list, or its first element is not a `Document`.
        """
        if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
            msg = (
                "HuggingFaceAPIDocumentEmbedder expects a list of Documents as input."
                " In case you want to embed a string, please use the HuggingFaceAPITextEmbedder."
            )
            raise TypeError(msg)

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)

        embeddings = await self._embed_batch_async(texts_to_embed=texts_to_embed, batch_size=self.batch_size)

        # Return copies carrying the embedding instead of mutating the caller's documents.
        new_documents = [replace(doc, embedding=emb) for doc, emb in zip(documents, embeddings, strict=True)]

        return {"documents": new_documents}