transformers-haystack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- haystack_integrations/components/classifiers/py.typed +0 -0
- haystack_integrations/components/classifiers/transformers/__init__.py +6 -0
- haystack_integrations/components/classifiers/transformers/zero_shot_document_classifier.py +247 -0
- haystack_integrations/components/common/py.typed +0 -0
- haystack_integrations/components/common/transformers/__init__.py +3 -0
- haystack_integrations/components/common/transformers/utils.py +234 -0
- haystack_integrations/components/extractors/py.typed +0 -0
- haystack_integrations/components/extractors/transformers/__init__.py +6 -0
- haystack_integrations/components/extractors/transformers/named_entity_extractor.py +262 -0
- haystack_integrations/components/generators/py.typed +0 -0
- haystack_integrations/components/generators/transformers/__init__.py +6 -0
- haystack_integrations/components/generators/transformers/chat/__init__.py +3 -0
- haystack_integrations/components/generators/transformers/chat/chat_generator.py +666 -0
- haystack_integrations/components/readers/py.typed +0 -0
- haystack_integrations/components/readers/transformers/__init__.py +6 -0
- haystack_integrations/components/readers/transformers/extractive_reader.py +662 -0
- haystack_integrations/components/routers/py.typed +0 -0
- haystack_integrations/components/routers/transformers/__init__.py +7 -0
- haystack_integrations/components/routers/transformers/text_router.py +196 -0
- haystack_integrations/components/routers/transformers/zero_shot_text_router.py +205 -0
- transformers_haystack-0.1.0.dist-info/METADATA +38 -0
- transformers_haystack-0.1.0.dist-info/RECORD +24 -0
- transformers_haystack-0.1.0.dist-info/WHEEL +4 -0
- transformers_haystack-0.1.0.dist-info/licenses/LICENSE.txt +201 -0
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
2
|
+
#
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, replace
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from haystack import ComponentError, DeserializationError, Document, component, default_from_dict, default_to_dict
|
|
9
|
+
from haystack.utils.auth import Secret
|
|
10
|
+
from haystack.utils.device import ComponentDevice
|
|
11
|
+
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs
|
|
12
|
+
|
|
13
|
+
from haystack_integrations.components.common.transformers.utils import _resolve_hf_pipeline_kwargs
|
|
14
|
+
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
|
|
15
|
+
from transformers import Pipeline as HfPipeline
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
|
|
19
|
+
class NamedEntityAnnotation:
|
|
20
|
+
"""
|
|
21
|
+
Describes a single NER annotation.
|
|
22
|
+
|
|
23
|
+
:param entity:
|
|
24
|
+
Entity label.
|
|
25
|
+
:param start:
|
|
26
|
+
Start index of the entity in the document.
|
|
27
|
+
:param end:
|
|
28
|
+
End index of the entity in the document.
|
|
29
|
+
:param score:
|
|
30
|
+
Score calculated by the model.
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
entity: str
|
|
34
|
+
start: int
|
|
35
|
+
end: int
|
|
36
|
+
score: float | None = None
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@component
class TransformersNamedEntityExtractor:
    """
    Annotates named entities in a collection of documents.

    The component can be used with any token classification model from the
    [Hugging Face model hub](https://huggingface.co/models). Annotations are
    stored as a list of `NamedEntityAnnotation` objects in each document's
    metadata under the `named_entities` key.

    Usage example:
    ```python
    from haystack import Document

    from haystack_integrations.components.extractors.transformers import TransformersNamedEntityExtractor

    documents = [
        Document(content="I'm Merlin, the happy pig!"),
        Document(content="My name is Clara and I live in Berkeley, California."),
    ]
    extractor = TransformersNamedEntityExtractor(model="dslim/bert-base-NER")
    results = extractor.run(documents=documents)["documents"]
    annotations = [TransformersNamedEntityExtractor.get_stored_annotations(doc) for doc in results]
    print(annotations)
    ```
    """

    _METADATA_KEY = "named_entities"

    def __init__(
        self,
        *,
        model: str,
        pipeline_kwargs: dict[str, Any] | None = None,
        device: ComponentDevice | None = None,
        token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
    ) -> None:
        """
        Create a Named Entity extractor component.

        :param model:
            Name of the model or a path to the model on
            the local disk.
        :param pipeline_kwargs:
            Keyword arguments passed to the Hugging Face pipeline. The component
            always sets `task`, `model`, `tokenizer` and `aggregation_strategy`
            itself; values given here for those keys are ignored.
        :param device:
            The device on which the model is loaded. If `None`,
            the default device is automatically selected. If a
            device/device map is specified in `pipeline_kwargs`,
            it overrides this parameter.
        :param token:
            The API token to download private models from Hugging Face.
        """
        self.token = token
        self.model_name_or_path = model
        self.device = ComponentDevice.resolve_device(device)

        self.pipeline_kwargs = _resolve_hf_pipeline_kwargs(
            huggingface_pipeline_kwargs=pipeline_kwargs or {},
            model=model,
            task="ner",
            supported_tasks=["ner"],
            device=self.device,
            token=token,
        )

        # Populated lazily by warm_up().
        self.tokenizer: Any = None
        self.model: AutoModelForTokenClassification | None = None
        self.pipeline: HfPipeline | None = None
        self._warmed_up: bool = False

    def warm_up(self) -> None:
        """
        Initialize the component by loading the tokenizer, the model and the NER pipeline.

        Calling it again after a successful warm-up is a no-op.

        :raises ComponentError:
            If the component fails to initialize successfully.
        """
        if self._warmed_up:
            return

        try:
            token = self.pipeline_kwargs.get("token", None)
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, token=token)
            self.model = AutoModelForTokenClassification.from_pretrained(self.model_name_or_path, token=token)

            pipeline_params: dict[str, Any] = {
                "task": "ner",
                "model": self.model,
                "tokenizer": self.tokenizer,
                "aggregation_strategy": "simple",
            }
            # User kwargs fill in everything else but never replace the keys set above.
            pipeline_params.update({k: v for k, v in self.pipeline_kwargs.items() if k not in pipeline_params})
            # overwrite=False: a device/device_map already present in the kwargs wins.
            self.device.update_hf_kwargs(pipeline_params, overwrite=False)
            self.pipeline = pipeline(**pipeline_params)
            self._warmed_up = True
        except Exception as e:
            msg = f"{self.__class__.__name__} failed to initialize."
            raise ComponentError(msg) from e

    @component.output_types(documents=list[Document])
    def run(self, documents: list[Document], batch_size: int = 1) -> dict[str, Any]:
        """
        Annotate named entities in each document and store the annotations in the document's metadata.

        The input documents are not modified; copies with updated metadata are returned.

        :param documents:
            Documents to process.
        :param batch_size:
            Batch size used for processing the documents.
        :returns:
            Processed documents.
        :raises ComponentError:
            If the model fails to process a document.
        """
        if not self._warmed_up:
            self.warm_up()

        if not documents:
            return {"documents": []}

        # Documents without content are annotated as empty strings so results stay aligned with the input.
        texts = [doc.content if doc.content is not None else "" for doc in documents]
        annotations = self._annotate(texts, batch_size=batch_size)

        if len(annotations) != len(documents):
            msg = (
                "NER model did not return the correct number of annotations; "
                f"got {len(annotations)} but expected {len(documents)}"
            )
            raise ComponentError(msg)

        new_documents = [
            replace(doc, meta={**doc.meta, self._METADATA_KEY: doc_annotations})
            for doc, doc_annotations in zip(documents, annotations, strict=True)
        ]
        return {"documents": new_documents}

    def _annotate(self, texts: list[str], *, batch_size: int = 1) -> list[list[NamedEntityAnnotation]]:
        """
        Predict annotations for a collection of documents.

        :param texts:
            Raw texts to be annotated.
        :param batch_size:
            Size of text batches that are
            passed to the model.
        :returns:
            NER annotations, one list per input text.
        :raises ComponentError:
            If the NER pipeline has not been created.
        """
        # An explicit check (instead of an `assert`, which is stripped under `python -O`)
        # also covers a partial warm-up where tokenizer and model loaded but the pipeline did not.
        if self.pipeline is None:
            msg = "NER model was not initialized - Did you call `warm_up()`?"
            raise ComponentError(msg)

        if not texts:
            return []

        outputs = self.pipeline(texts, batch_size=batch_size)
        return [[self._to_annotation(raw) for raw in doc_output] for doc_output in outputs]

    @staticmethod
    def _to_annotation(raw: dict[str, Any]) -> NamedEntityAnnotation:
        """
        Convert one raw prediction dict from the Hugging Face pipeline into a `NamedEntityAnnotation`.

        :param raw:
            A single prediction; must contain `start`, `end`, `score` and either `entity` or `entity_group`.
        :returns:
            The corresponding annotation.
        """
        # The label key depends on the pipeline's aggregation mode: "entity" or "entity_group".
        entity = raw["entity"] if "entity" in raw else raw["entity_group"]
        score = raw["score"]
        return NamedEntityAnnotation(
            entity=entity,
            start=raw["start"],
            end=raw["end"],
            # The pipeline may return NumPy scalars (e.g. float32); store a plain float
            # to honour the `float | None` field type.
            score=float(score) if score is not None else None,
        )

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        Any `token` entry is removed from the serialized `pipeline_kwargs`; the
        `token` init parameter is serialized separately.

        :returns:
            Dictionary with serialized data.
        """

        def _copy_dicts(value: Any) -> Any:
            # Copy every (nested) dict so the in-place edits below cannot reach
            # `self.pipeline_kwargs`; non-dict values are shared as-is.
            if isinstance(value, dict):
                return {key: _copy_dicts(item) for key, item in value.items()}
            return value

        serialization_dict = default_to_dict(
            self,
            model=self.model_name_or_path,
            device=self.device,
            pipeline_kwargs=_copy_dicts(self.pipeline_kwargs),
            token=self.token,
        )

        hf_pipeline_kwargs = serialization_dict["init_parameters"]["pipeline_kwargs"]
        # Never write the token into the serialized kwargs.
        hf_pipeline_kwargs.pop("token", None)

        serialize_hf_model_kwargs(hf_pipeline_kwargs)
        return serialization_dict

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "TransformersNamedEntityExtractor":
        """
        Deserializes the component from a dictionary.

        Note: the nested `pipeline_kwargs` in `data` are deserialized in place.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        :raises DeserializationError:
            If the dictionary cannot be turned into a component.
        """
        try:
            init_params = data.get("init_parameters", {})
            hf_pipeline_kwargs = init_params.get("pipeline_kwargs")
            deserialize_hf_model_kwargs(hf_pipeline_kwargs or {})
            return default_from_dict(cls, data)
        except Exception as e:
            msg = f"Couldn't deserialize {cls.__name__} instance"
            raise DeserializationError(msg) from e

    @property
    def initialized(self) -> bool:
        """
        Whether the extractor has loaded either a tokenizer and model, or a pipeline.
        """
        return (self.tokenizer is not None and self.model is not None) or self.pipeline is not None

    @classmethod
    def get_stored_annotations(cls, document: Document) -> list[NamedEntityAnnotation] | None:
        """
        Returns the document's named entity annotations stored in its metadata, if any.

        :param document:
            Document whose annotations are to be fetched.
        :returns:
            The stored annotations, or `None` if the document has none.
        """

        return document.meta.get(cls._METADATA_KEY)
|
|
File without changes
|