transformers-haystack 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24):
  1. haystack_integrations/components/classifiers/py.typed +0 -0
  2. haystack_integrations/components/classifiers/transformers/__init__.py +6 -0
  3. haystack_integrations/components/classifiers/transformers/zero_shot_document_classifier.py +247 -0
  4. haystack_integrations/components/common/py.typed +0 -0
  5. haystack_integrations/components/common/transformers/__init__.py +3 -0
  6. haystack_integrations/components/common/transformers/utils.py +234 -0
  7. haystack_integrations/components/extractors/py.typed +0 -0
  8. haystack_integrations/components/extractors/transformers/__init__.py +6 -0
  9. haystack_integrations/components/extractors/transformers/named_entity_extractor.py +262 -0
  10. haystack_integrations/components/generators/py.typed +0 -0
  11. haystack_integrations/components/generators/transformers/__init__.py +6 -0
  12. haystack_integrations/components/generators/transformers/chat/__init__.py +3 -0
  13. haystack_integrations/components/generators/transformers/chat/chat_generator.py +666 -0
  14. haystack_integrations/components/readers/py.typed +0 -0
  15. haystack_integrations/components/readers/transformers/__init__.py +6 -0
  16. haystack_integrations/components/readers/transformers/extractive_reader.py +662 -0
  17. haystack_integrations/components/routers/py.typed +0 -0
  18. haystack_integrations/components/routers/transformers/__init__.py +7 -0
  19. haystack_integrations/components/routers/transformers/text_router.py +196 -0
  20. haystack_integrations/components/routers/transformers/zero_shot_text_router.py +205 -0
  21. transformers_haystack-0.1.0.dist-info/METADATA +38 -0
  22. transformers_haystack-0.1.0.dist-info/RECORD +24 -0
  23. transformers_haystack-0.1.0.dist-info/WHEEL +4 -0
  24. transformers_haystack-0.1.0.dist-info/licenses/LICENSE.txt +201 -0
@@ -0,0 +1,262 @@
1
+ # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ from dataclasses import dataclass, replace
6
+ from typing import Any
7
+
8
+ from haystack import ComponentError, DeserializationError, Document, component, default_from_dict, default_to_dict
9
+ from haystack.utils.auth import Secret
10
+ from haystack.utils.device import ComponentDevice
11
+ from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs
12
+
13
+ from haystack_integrations.components.common.transformers.utils import _resolve_hf_pipeline_kwargs
14
+ from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
15
+ from transformers import Pipeline as HfPipeline
16
+
17
+
18
+ @dataclass
19
+ class NamedEntityAnnotation:
20
+ """
21
+ Describes a single NER annotation.
22
+
23
+ :param entity:
24
+ Entity label.
25
+ :param start:
26
+ Start index of the entity in the document.
27
+ :param end:
28
+ End index of the entity in the document.
29
+ :param score:
30
+ Score calculated by the model.
31
+ """
32
+
33
+ entity: str
34
+ start: int
35
+ end: int
36
+ score: float | None = None
37
+
38
+
39
@component
class TransformersNamedEntityExtractor:
    """
    Annotates named entities in a collection of documents.

    The component can be used with any token classification model from the
    [Hugging Face model hub](https://huggingface.co/models). Annotations are
    stored as metadata in the documents.

    Usage example:
    ```python
    from haystack import Document

    from haystack_integrations.components.extractors.transformers import TransformersNamedEntityExtractor

    documents = [
        Document(content="I'm Merlin, the happy pig!"),
        Document(content="My name is Clara and I live in Berkeley, California."),
    ]
    extractor = TransformersNamedEntityExtractor(model="dslim/bert-base-NER")
    results = extractor.run(documents=documents)["documents"]
    annotations = [TransformersNamedEntityExtractor.get_stored_annotations(doc) for doc in results]
    print(annotations)
    ```
    """

    # Key under which the list of `NamedEntityAnnotation`s is stored in each document's `meta`.
    _METADATA_KEY = "named_entities"

    def __init__(
        self,
        *,
        model: str,
        pipeline_kwargs: dict[str, Any] | None = None,
        device: ComponentDevice | None = None,
        token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
    ) -> None:
        """
        Create a Named Entity extractor component.

        No model is loaded here; loading happens in `warm_up()`.

        :param model:
            Name of the model or a path to the model on
            the local disk.
        :param pipeline_kwargs:
            Keyword arguments passed to the pipeline. The
            pipeline can override these arguments.
        :param device:
            The device on which the model is loaded. If `None`,
            the default device is automatically selected. If a
            device/device map is specified in `pipeline_kwargs`,
            it overrides this parameter.
        :param token:
            The API token to download private models from Hugging Face.
        """
        self.token = token
        self.model_name_or_path = model
        self.device = ComponentDevice.resolve_device(device)

        # Presumably merges model/task/device/token into the user kwargs (helper defined elsewhere);
        # `warm_up()` later reads the resolved "token" entry from this dict.
        self.pipeline_kwargs = _resolve_hf_pipeline_kwargs(
            huggingface_pipeline_kwargs=pipeline_kwargs or {},
            model=model,
            task="ner",
            supported_tasks=["ner"],
            device=self.device,
            token=token,
        )

        # Populated by `warm_up()`.
        self.tokenizer: Any = None
        self.model: AutoModelForTokenClassification | None = None
        self.pipeline: HfPipeline | None = None
        self._warmed_up: bool = False

    def warm_up(self) -> None:
        """
        Load the tokenizer and model and build the Hugging Face NER pipeline.

        Calling it again after a successful warm-up is a no-op.

        :raises ComponentError:
            If the component fails to initialize successfully.
        """
        if self._warmed_up:
            return

        try:
            token = self.pipeline_kwargs.get("token", None)
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, token=token)
            self.model = AutoModelForTokenClassification.from_pretrained(self.model_name_or_path, token=token)

            pipeline_params: dict[str, Any] = {
                "task": "ner",
                "model": self.model,
                "tokenizer": self.tokenizer,
                "aggregation_strategy": "simple",
            }
            # User kwargs only fill keys not fixed above (e.g. the "model" name string is
            # ignored in favour of the loaded model object).
            pipeline_params.update({k: v for k, v in self.pipeline_kwargs.items() if k not in pipeline_params})
            self.device.update_hf_kwargs(pipeline_params, overwrite=False)
            self.pipeline = pipeline(**pipeline_params)
            self._warmed_up = True
        except Exception as e:
            msg = f"{self.__class__.__name__} failed to initialize."
            raise ComponentError(msg) from e

    @component.output_types(documents=list[Document])
    def run(self, documents: list[Document], batch_size: int = 1) -> dict[str, Any]:
        """
        Annotate named entities in each document and store the annotations in the document's metadata.

        The component is warmed up lazily on first use. Input documents are not modified:
        each output document is a copy whose `meta` additionally holds the annotations.

        :param documents:
            Documents to process. Documents without content are annotated as empty text.
        :param batch_size:
            Batch size used for processing the documents.
        :returns:
            A dictionary with key `documents` holding the processed documents.
        :raises ComponentError:
            If the component fails to initialize or the model does not return one
            annotation list per document.
        """
        if not self._warmed_up:
            self.warm_up()

        texts = [doc.content if doc.content is not None else "" for doc in documents]
        annotations = self._annotate(texts, batch_size=batch_size)

        if len(annotations) != len(documents):
            msg = (
                "NER model did not return the correct number of annotations; "
                f"got {len(annotations)} but expected {len(documents)}"
            )
            raise ComponentError(msg)

        new_documents = []
        for doc, doc_annotations in zip(documents, annotations, strict=True):
            new_meta = {**doc.meta, self._METADATA_KEY: doc_annotations}
            new_documents.append(replace(doc, meta=new_meta))

        return {"documents": new_documents}

    def _annotate(self, texts: list[str], *, batch_size: int = 1) -> list[list[NamedEntityAnnotation]]:
        """
        Predict annotations for a collection of documents.

        :param texts:
            Raw texts to be annotated.
        :param batch_size:
            Size of text batches that are
            passed to the model.
        :returns:
            One list of NER annotations per input text, in input order.
        :raises ComponentError:
            If the pipeline has not been built yet.
        """
        # Check the pipeline itself: `initialized` can be True when the tokenizer and model
        # loaded but pipeline construction failed. An explicit error also survives `python -O`.
        if not self.initialized or self.pipeline is None:
            msg = "NER model was not initialized - Did you call `warm_up()`?"
            raise ComponentError(msg)

        # The HF token-classification pipeline rejects an empty input list, so there is
        # nothing to send to the model; return "no annotations" directly.
        if not texts:
            return []

        outputs = self.pipeline(texts, batch_size=batch_size)
        return [
            [
                NamedEntityAnnotation(
                    # Aggregated outputs use "entity_group"; non-aggregated ones use "entity".
                    entity=annotation["entity"] if "entity" in annotation else annotation["entity_group"],
                    start=annotation["start"],
                    end=annotation["end"],
                    # The pipeline may return NumPy scalar scores; cast to a plain float to match
                    # the annotation's declared type and keep metadata JSON-serializable.
                    score=float(annotation["score"]),
                )
                for annotation in text_annotations
            ]
            for text_annotations in outputs
        ]

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        The resolved HF token is removed from the serialized pipeline kwargs; the `token`
        init parameter is serialized separately.

        :returns:
            Dictionary with serialized data.
        """

        def _detach(value: Any) -> Any:
            # Rebuild every (nested) dict so the in-place edits below can never leak into
            # `self.pipeline_kwargs`, which `warm_up()` still relies on (e.g. for "token").
            if isinstance(value, dict):
                return {key: _detach(item) for key, item in value.items()}
            return value

        serialization_dict = default_to_dict(
            self,
            model=self.model_name_or_path,
            device=self.device,
            pipeline_kwargs=_detach(self.pipeline_kwargs),
            token=self.token,
        )

        hf_pipeline_kwargs = serialization_dict["init_parameters"]["pipeline_kwargs"]
        # Never persist the resolved secret value.
        hf_pipeline_kwargs.pop("token", None)

        serialize_hf_model_kwargs(hf_pipeline_kwargs)
        return serialization_dict

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "TransformersNamedEntityExtractor":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from. Its nested pipeline kwargs are deserialized in place.
        :returns:
            Deserialized component.
        :raises DeserializationError:
            If any step of deserialization fails.
        """
        try:
            init_params = data.get("init_parameters", {})
            hf_pipeline_kwargs = init_params.get("pipeline_kwargs")
            deserialize_hf_model_kwargs(hf_pipeline_kwargs or {})
            return default_from_dict(cls, data)
        except Exception as e:
            msg = f"Couldn't deserialize {cls.__name__} instance"
            raise DeserializationError(msg) from e

    @property
    def initialized(self) -> bool:
        """
        Returns if the extractor is ready to annotate text.
        """
        return (self.tokenizer is not None and self.model is not None) or self.pipeline is not None

    @classmethod
    def get_stored_annotations(cls, document: Document) -> list[NamedEntityAnnotation] | None:
        """
        Returns the document's named entity annotations stored in its metadata, if any.

        :param document:
            Document whose annotations are to be fetched.
        :returns:
            The stored annotations, or `None` if the document has none.
        """
        return document.meta.get(cls._METADATA_KEY)
File without changes
@@ -0,0 +1,6 @@
1
+ # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ from .chat.chat_generator import TransformersChatGenerator
5
+
6
# Explicit public API of this package: only the chat generator imported from `.chat.chat_generator`.
__all__ = ["TransformersChatGenerator"]
@@ -0,0 +1,3 @@
1
+ # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0