huggingface-api-haystack 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,262 @@
1
+ # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ from typing import Any
6
+
7
+ from haystack import component, default_from_dict, default_to_dict, logging
8
+ from haystack.utils import Secret
9
+ from haystack.utils.url_validation import is_valid_http_url
10
+ from huggingface_hub import AsyncInferenceClient, InferenceClient
11
+
12
+ from haystack_integrations.components.common.huggingface_api.utils import (
13
+ HFEmbeddingAPIType,
14
+ HFModelType,
15
+ _check_valid_model,
16
+ )
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ # Maximum number of array dimensions accepted for an embedding: the API can return shape (embedding_dim,) or (1, embedding_dim)
21
+ _MAX_EMBEDDING_NDIM = 2
22
+
23
+
24
@component
class HuggingFaceAPITextEmbedder:
    """
    Embeds strings using Hugging Face APIs.

    Use it with the following Hugging Face APIs:
    - [Free Serverless Inference API](https://huggingface.co/inference-api)
    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
    - [Self-hosted Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference)

    ### Usage examples

    #### With free serverless inference API

    ```python
    from haystack_integrations.components.embedders.huggingface_api import HuggingFaceAPITextEmbedder
    from haystack.utils import Secret

    text_embedder = HuggingFaceAPITextEmbedder(api_type="serverless_inference_api",
                                               api_params={"model": "BAAI/bge-small-en-v1.5"},
                                               token=Secret.from_token("<your-api-key>"))

    print(text_embedder.run("I love pizza!"))

    # {'embedding': [0.017020374536514282, -0.023255806416273117, ...]}
    ```

    #### With paid inference endpoints

    ```python
    from haystack_integrations.components.embedders.huggingface_api import HuggingFaceAPITextEmbedder
    from haystack.utils import Secret

    text_embedder = HuggingFaceAPITextEmbedder(api_type="inference_endpoints",
                                               api_params={"url": "<your-inference-endpoint-url>"},
                                               token=Secret.from_token("<your-api-key>"))

    print(text_embedder.run("I love pizza!"))

    # {'embedding': [0.017020374536514282, -0.023255806416273117, ...]}
    ```

    #### With self-hosted text embeddings inference

    ```python
    from haystack_integrations.components.embedders.huggingface_api import HuggingFaceAPITextEmbedder
    from haystack.utils import Secret

    text_embedder = HuggingFaceAPITextEmbedder(api_type="text_embeddings_inference",
                                               api_params={"url": "http://localhost:8080"})

    print(text_embedder.run("I love pizza!"))

    # {'embedding': [0.017020374536514282, -0.023255806416273117, ...]}
    ```
    """

    def __init__(
        self,
        api_type: HFEmbeddingAPIType | str,
        api_params: dict[str, str],
        token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        prefix: str = "",
        suffix: str = "",
        truncate: bool | None = True,
        normalize: bool | None = False,
    ) -> None:
        """
        Creates a HuggingFaceAPITextEmbedder component.

        :param api_type:
            The type of Hugging Face API to use.
        :param api_params:
            A dictionary with the following keys:
            - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
            `TEXT_EMBEDDINGS_INFERENCE`.
        :param token: The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param prefix:
            A string to add at the beginning of each text.
        :param suffix:
            A string to add at the end of each text.
        :param truncate:
            Truncates the input text to the maximum length supported by the model.
            Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
            if the backend uses Text Embeddings Inference.
            If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
        :param normalize:
            Normalizes the embeddings to unit length.
            Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
            if the backend uses Text Embeddings Inference.
            If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
        :raises ValueError:
            If the key required for the chosen `api_type` is missing from `api_params`,
            if the provided `url` is not a valid HTTP URL, or if `api_type` is unknown.
        """
        if isinstance(api_type, str):
            api_type = HFEmbeddingAPIType.from_str(api_type)

        if api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
            model = api_params.get("model")
            if model is None:
                msg = "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
                raise ValueError(msg)
            _check_valid_model(model, HFModelType.EMBEDDING, token)
            model_or_url = model
        elif api_type in (HFEmbeddingAPIType.INFERENCE_ENDPOINTS, HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE):
            url = api_params.get("url")
            if url is None:
                msg = (
                    "To use Text Embeddings Inference or Inference Endpoints, you need to specify the `url` "
                    "parameter in `api_params`."
                )
                raise ValueError(msg)
            if not is_valid_http_url(url):
                msg = f"Invalid URL: {url}"
                raise ValueError(msg)
            model_or_url = url
        else:
            msg = f"Unknown api_type {api_type}"
            raise ValueError(msg)

        self.api_type = api_type
        self.api_params = api_params
        self.token = token
        self.prefix = prefix
        self.suffix = suffix
        self.truncate = truncate
        self.normalize = normalize

        # Resolve the secret once and share the value between the sync and async clients.
        resolved_token = token.resolve_value() if token else None
        self._client = InferenceClient(model_or_url, token=resolved_token)
        self._async_client = AsyncInferenceClient(model_or_url, token=resolved_token)

    def _prepare_input(self, text: str) -> tuple[str, bool | None, bool | None]:
        """
        Validates the input text and computes the arguments for the feature extraction call.

        :param text:
            Text to embed.
        :returns:
            A tuple `(text_to_embed, truncate, normalize)` where `text_to_embed` is the input wrapped in
            `prefix` and `suffix`, and `truncate` / `normalize` are the configured values, or `None` for
            both when `api_type` is `SERVERLESS_INFERENCE_API`.
        :raises TypeError:
            If `text` is not a string.
        """
        if not isinstance(text, str):
            msg = (
                "HuggingFaceAPITextEmbedder expects a string as an input. "
                "In case you want to embed a list of Documents, please use the HuggingFaceAPIDocumentEmbedder."
            )
            raise TypeError(msg)

        truncate = self.truncate
        normalize = self.normalize

        if self.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
            # NOTE: the constructor defaults for `truncate` and `normalize` are not None, so these
            # warnings are logged on every call unless the caller explicitly passed None.
            if truncate is not None:
                msg = "`truncate` parameter is not supported for Serverless Inference API. It will be ignored."
                logger.warning(msg)
                truncate = None
            if normalize is not None:
                msg = "`normalize` parameter is not supported for Serverless Inference API. It will be ignored."
                logger.warning(msg)
                normalize = None

        text_to_embed = self.prefix + text + self.suffix

        return text_to_embed, truncate, normalize

    @staticmethod
    def _embedding_to_list(np_embedding: Any) -> list[float]:
        """
        Validates the shape of an embedding returned by the API and flattens it to a list.

        Shared by `run` and `run_async` so both paths apply the same validation.

        :param np_embedding:
            Array returned by the feature extraction call; it must expose `ndim`, `shape` and `flatten()`.
        :returns:
            The embedding as a flat list.
        :raises ValueError:
            If the array has more than `_MAX_EMBEDDING_NDIM` dimensions, or has exactly that many
            dimensions with a first dimension different from 1.
        """
        has_too_many_dims = np_embedding.ndim > _MAX_EMBEDDING_NDIM
        has_multiple_rows = np_embedding.ndim == _MAX_EMBEDDING_NDIM and np_embedding.shape[0] != 1
        if has_too_many_dims or has_multiple_rows:
            msg = f"Expected embedding shape (1, embedding_dim) or (embedding_dim,), got {np_embedding.shape}"
            raise ValueError(msg)

        return np_embedding.flatten().tolist()

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            api_type=str(self.api_type),
            api_params=self.api_params,
            prefix=self.prefix,
            suffix=self.suffix,
            token=self.token,
            truncate=self.truncate,
            normalize=self.normalize,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "HuggingFaceAPITextEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        return default_from_dict(cls, data)

    @component.output_types(embedding=list[float])
    def run(self, text: str) -> dict[str, Any]:
        """
        Embeds a single string.

        :param text:
            Text to embed.

        :returns:
            A dictionary with the following keys:
            - `embedding`: The embedding of the input text.
        :raises TypeError:
            If `text` is not a string.
        :raises ValueError:
            If the returned embedding does not have shape `(embedding_dim,)` or `(1, embedding_dim)`.
        """
        text_to_embed, truncate_val, normalize_val = self._prepare_input(text)

        np_embedding = self._client.feature_extraction(
            text=text_to_embed, truncate=truncate_val, normalize=normalize_val
        )

        return {"embedding": self._embedding_to_list(np_embedding)}

    @component.output_types(embedding=list[float])
    async def run_async(self, text: str) -> dict[str, Any]:
        """
        Embeds a single string asynchronously.

        :param text:
            Text to embed.

        :returns:
            A dictionary with the following keys:
            - `embedding`: The embedding of the input text.
        :raises TypeError:
            If `text` is not a string.
        :raises ValueError:
            If the returned embedding does not have shape `(embedding_dim,)` or `(1, embedding_dim)`.
        """
        text_to_embed, truncate_val, normalize_val = self._prepare_input(text)

        np_embedding = await self._async_client.feature_extraction(
            text=text_to_embed, truncate=truncate_val, normalize=normalize_val
        )

        return {"embedding": self._embedding_to_list(np_embedding)}
File without changes
@@ -0,0 +1,6 @@
1
+ # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ from .chat.chat_generator import HuggingFaceAPIChatGenerator
5
+
6
+ __all__ = ["HuggingFaceAPIChatGenerator"]
@@ -0,0 +1,3 @@
1
+ # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0