huggingface-api-haystack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- haystack_integrations/components/common/huggingface_api/__init__.py +3 -0
- haystack_integrations/components/common/huggingface_api/utils.py +112 -0
- haystack_integrations/components/common/py.typed +0 -0
- haystack_integrations/components/embedders/huggingface_api/__init__.py +7 -0
- haystack_integrations/components/embedders/huggingface_api/document_embedder.py +382 -0
- haystack_integrations/components/embedders/huggingface_api/text_embedder.py +262 -0
- haystack_integrations/components/embedders/py.typed +0 -0
- haystack_integrations/components/generators/huggingface_api/__init__.py +6 -0
- haystack_integrations/components/generators/huggingface_api/chat/__init__.py +3 -0
- haystack_integrations/components/generators/huggingface_api/chat/chat_generator.py +738 -0
- haystack_integrations/components/generators/py.typed +0 -0
- huggingface_api_haystack-0.1.0.dist-info/METADATA +40 -0
- huggingface_api_haystack-0.1.0.dist-info/RECORD +15 -0
- huggingface_api_haystack-0.1.0.dist-info/WHEEL +4 -0
- huggingface_api_haystack-0.1.0.dist-info/licenses/LICENSE.txt +201 -0
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
2
|
+
#
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from haystack import component, default_from_dict, default_to_dict, logging
|
|
8
|
+
from haystack.utils import Secret
|
|
9
|
+
from haystack.utils.url_validation import is_valid_http_url
|
|
10
|
+
from huggingface_hub import AsyncInferenceClient, InferenceClient
|
|
11
|
+
|
|
12
|
+
from haystack_integrations.components.common.huggingface_api.utils import (
|
|
13
|
+
HFEmbeddingAPIType,
|
|
14
|
+
HFModelType,
|
|
15
|
+
_check_valid_model,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
# `logging` here is the logging module imported from `haystack`, not the standard-library one.
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
# embeddings returned by the API can have shape (embedding_dim,) or (1, embedding_dim)
|
|
21
|
+
# Upper bound on the number of array dimensions accepted for a returned embedding;
# a result with more dimensions than this is rejected with a ValueError.
_MAX_EMBEDDING_NDIM = 2
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@component
class HuggingFaceAPITextEmbedder:
    """
    Embeds strings using Hugging Face APIs.

    Use it with the following Hugging Face APIs:
    - [Free Serverless Inference API](https://huggingface.co/inference-api)
    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
    - [Self-hosted Text Embeddings Inference](https://github.com/huggingface/text-embeddings-inference)

    ### Usage examples

    #### With free serverless inference API

    ```python
    from haystack_integrations.components.embedders.huggingface_api import HuggingFaceAPITextEmbedder
    from haystack.utils import Secret

    text_embedder = HuggingFaceAPITextEmbedder(api_type="serverless_inference_api",
                                               api_params={"model": "BAAI/bge-small-en-v1.5"},
                                               token=Secret.from_token("<your-api-key>"))

    print(text_embedder.run("I love pizza!"))

    # {'embedding': [0.017020374536514282, -0.023255806416273117, ...]}
    ```

    #### With paid inference endpoints

    ```python
    from haystack_integrations.components.embedders.huggingface_api import HuggingFaceAPITextEmbedder
    from haystack.utils import Secret
    text_embedder = HuggingFaceAPITextEmbedder(api_type="inference_endpoints",
                                               api_params={"url": "<your-inference-endpoint-url>"},
                                               token=Secret.from_token("<your-api-key>"))

    print(text_embedder.run("I love pizza!"))

    # {'embedding': [0.017020374536514282, -0.023255806416273117, ...]}
    ```

    #### With self-hosted text embeddings inference

    ```python
    from haystack_integrations.components.embedders.huggingface_api import HuggingFaceAPITextEmbedder
    from haystack.utils import Secret

    text_embedder = HuggingFaceAPITextEmbedder(api_type="text_embeddings_inference",
                                               api_params={"url": "http://localhost:8080"})

    print(text_embedder.run("I love pizza!"))

    # {'embedding': [0.017020374536514282, -0.023255806416273117, ...]}
    ```
    """

    def __init__(
        self,
        api_type: HFEmbeddingAPIType | str,
        api_params: dict[str, str],
        token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        prefix: str = "",
        suffix: str = "",
        truncate: bool | None = True,
        normalize: bool | None = False,
    ) -> None:
        """
        Creates a HuggingFaceAPITextEmbedder component.

        :param api_type:
            The type of Hugging Face API to use. A string is converted with `HFEmbeddingAPIType.from_str`.
        :param api_params:
            A dictionary with the following keys:
            - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
            `TEXT_EMBEDDINGS_INFERENCE`.
        :param token: The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param prefix:
            A string to add at the beginning of each text.
        :param suffix:
            A string to add at the end of each text.
        :param truncate:
            Truncates the input text to the maximum length supported by the model.
            Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
            if the backend uses Text Embeddings Inference.
            If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.
        :param normalize:
            Normalizes the embeddings to unit length.
            Applicable when `api_type` is `TEXT_EMBEDDINGS_INFERENCE`, or `INFERENCE_ENDPOINTS`
            if the backend uses Text Embeddings Inference.
            If `api_type` is `SERVERLESS_INFERENCE_API`, this parameter is ignored.

        :raises ValueError:
            If the key required by the selected `api_type` (`model` or `url`) is missing from `api_params`,
            if the given `url` is not a valid HTTP URL, or if `api_type` is not one of the supported types.
        """
        if isinstance(api_type, str):
            api_type = HFEmbeddingAPIType.from_str(api_type)

        if api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
            model = api_params.get("model")
            if model is None:
                msg = "To use the Serverless Inference API, you need to specify the `model` parameter in `api_params`."
                raise ValueError(msg)
            _check_valid_model(model, HFModelType.EMBEDDING, token)
            model_or_url = model
        elif api_type in (HFEmbeddingAPIType.INFERENCE_ENDPOINTS, HFEmbeddingAPIType.TEXT_EMBEDDINGS_INFERENCE):
            url = api_params.get("url")
            if url is None:
                msg = (
                    "To use Text Embeddings Inference or Inference Endpoints, you need to specify the `url` "
                    "parameter in `api_params`."
                )
                raise ValueError(msg)
            if not is_valid_http_url(url):
                msg = f"Invalid URL: {url}"
                raise ValueError(msg)
            model_or_url = url
        else:
            msg = f"Unknown api_type {api_type}"
            raise ValueError(msg)

        self.api_type = api_type
        self.api_params = api_params
        self.token = token
        self.prefix = prefix
        self.suffix = suffix
        self.truncate = truncate
        self.normalize = normalize

        # Resolve the secret once and hand the same value to both the sync and the async client.
        resolved_token = token.resolve_value() if token else None
        self._client = InferenceClient(model_or_url, token=resolved_token)
        self._async_client = AsyncInferenceClient(model_or_url, token=resolved_token)

    def _prepare_input(self, text: str) -> tuple[str, bool | None, bool | None]:
        """
        Validates the input and builds the arguments for the feature-extraction call.

        :param text:
            Text to embed.

        :returns:
            A tuple `(text_to_embed, truncate, normalize)`, where `text_to_embed` is `text` wrapped in the
            configured prefix and suffix. For the Serverless Inference API, `truncate` and `normalize` are
            returned as `None`, and a warning is logged for each one that was configured with a non-`None` value.

        :raises TypeError:
            If `text` is not a string.
        """
        if not isinstance(text, str):
            msg = (
                "HuggingFaceAPITextEmbedder expects a string as an input. "
                "In case you want to embed a list of Documents, please use the HuggingFaceAPIDocumentEmbedder."
            )
            raise TypeError(msg)

        truncate = self.truncate
        normalize = self.normalize

        if self.api_type == HFEmbeddingAPIType.SERVERLESS_INFERENCE_API:
            if truncate is not None:
                msg = "`truncate` parameter is not supported for Serverless Inference API. It will be ignored."
                logger.warning(msg)
                truncate = None
            if normalize is not None:
                msg = "`normalize` parameter is not supported for Serverless Inference API. It will be ignored."
                logger.warning(msg)
                normalize = None

        text_to_embed = self.prefix + text + self.suffix

        return text_to_embed, truncate, normalize

    @staticmethod
    def _to_embedding_list(np_embedding: Any) -> list[float]:
        """
        Checks the shape of the array returned by the API and flattens it into a plain list.

        Shared by `run` and `run_async` so that both apply exactly the same validation.

        :param np_embedding:
            The object returned by `feature_extraction`; it must expose `ndim`, `shape` and `flatten()`.

        :returns:
            The flattened embedding as a list.

        :raises ValueError:
            If the array has more than `_MAX_EMBEDDING_NDIM` dimensions, or has exactly that many
            dimensions with a first dimension other than 1.
        """
        too_many_dims = np_embedding.ndim > _MAX_EMBEDDING_NDIM
        # A 2-D result is only acceptable when it holds a single row, i.e. one embedding.
        multiple_rows = np_embedding.ndim == _MAX_EMBEDDING_NDIM and np_embedding.shape[0] != 1
        if too_many_dims or multiple_rows:
            error_msg = f"Expected embedding shape (1, embedding_dim) or (embedding_dim,), got {np_embedding.shape}"
            raise ValueError(error_msg)

        return np_embedding.flatten().tolist()

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            api_type=str(self.api_type),
            api_params=self.api_params,
            prefix=self.prefix,
            suffix=self.suffix,
            token=self.token,
            truncate=self.truncate,
            normalize=self.normalize,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "HuggingFaceAPITextEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        return default_from_dict(cls, data)

    @component.output_types(embedding=list[float])
    def run(self, text: str) -> dict[str, Any]:
        """
        Embeds a single string.

        :param text:
            Text to embed.

        :returns:
            A dictionary with the following keys:
            - `embedding`: The embedding of the input text.

        :raises TypeError:
            If `text` is not a string.
        :raises ValueError:
            If the returned embedding does not have shape `(embedding_dim,)` or `(1, embedding_dim)`.
        """
        text_to_embed, truncate_val, normalize_val = self._prepare_input(text)

        np_embedding = self._client.feature_extraction(
            text=text_to_embed, truncate=truncate_val, normalize=normalize_val
        )

        return {"embedding": self._to_embedding_list(np_embedding)}

    @component.output_types(embedding=list[float])
    async def run_async(self, text: str) -> dict[str, Any]:
        """
        Embeds a single string asynchronously.

        :param text:
            Text to embed.

        :returns:
            A dictionary with the following keys:
            - `embedding`: The embedding of the input text.

        :raises TypeError:
            If `text` is not a string.
        :raises ValueError:
            If the returned embedding does not have shape `(embedding_dim,)` or `(1, embedding_dim)`.
        """
        text_to_embed, truncate_val, normalize_val = self._prepare_input(text)

        np_embedding = await self._async_client.feature_extraction(
            text=text_to_embed, truncate=truncate_val, normalize=normalize_val
        )

        return {"embedding": self._to_embedding_list(np_embedding)}
|
|
File without changes
|