transformers-haystack 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24) hide show
  1. haystack_integrations/components/classifiers/py.typed +0 -0
  2. haystack_integrations/components/classifiers/transformers/__init__.py +6 -0
  3. haystack_integrations/components/classifiers/transformers/zero_shot_document_classifier.py +247 -0
  4. haystack_integrations/components/common/py.typed +0 -0
  5. haystack_integrations/components/common/transformers/__init__.py +3 -0
  6. haystack_integrations/components/common/transformers/utils.py +234 -0
  7. haystack_integrations/components/extractors/py.typed +0 -0
  8. haystack_integrations/components/extractors/transformers/__init__.py +6 -0
  9. haystack_integrations/components/extractors/transformers/named_entity_extractor.py +262 -0
  10. haystack_integrations/components/generators/py.typed +0 -0
  11. haystack_integrations/components/generators/transformers/__init__.py +6 -0
  12. haystack_integrations/components/generators/transformers/chat/__init__.py +3 -0
  13. haystack_integrations/components/generators/transformers/chat/chat_generator.py +666 -0
  14. haystack_integrations/components/readers/py.typed +0 -0
  15. haystack_integrations/components/readers/transformers/__init__.py +6 -0
  16. haystack_integrations/components/readers/transformers/extractive_reader.py +662 -0
  17. haystack_integrations/components/routers/py.typed +0 -0
  18. haystack_integrations/components/routers/transformers/__init__.py +7 -0
  19. haystack_integrations/components/routers/transformers/text_router.py +196 -0
  20. haystack_integrations/components/routers/transformers/zero_shot_text_router.py +205 -0
  21. transformers_haystack-0.1.0.dist-info/METADATA +38 -0
  22. transformers_haystack-0.1.0.dist-info/RECORD +24 -0
  23. transformers_haystack-0.1.0.dist-info/WHEEL +4 -0
  24. transformers_haystack-0.1.0.dist-info/licenses/LICENSE.txt +201 -0
File without changes
@@ -0,0 +1,6 @@
1
+ # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ from .zero_shot_document_classifier import TransformersZeroShotDocumentClassifier
5
+
6
+ __all__ = ["TransformersZeroShotDocumentClassifier"]
@@ -0,0 +1,247 @@
1
+ # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ from dataclasses import replace
6
+ from typing import Any
7
+
8
+ from haystack import Document, component, default_from_dict, default_to_dict
9
+ from haystack.utils import ComponentDevice, Secret
10
+ from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs
11
+
12
+ from haystack_integrations.components.common.transformers.utils import _resolve_hf_pipeline_kwargs
13
+ from transformers import Pipeline as HfPipeline
14
+ from transformers import pipeline
15
+
16
+
17
@component
class TransformersZeroShotDocumentClassifier:
    """
    Performs zero-shot classification of documents based on given labels and adds the predicted label to their metadata.

    The component uses a Hugging Face pipeline for zero-shot classification.
    Provide the model and the set of labels to be used for categorization during initialization.
    Additionally, you can configure the component to allow multiple labels to be true.

    Classification is run on the document's content field by default. If you want it to run on another field, set the
    `classification_field` to one of the document's metadata fields.

    Available models for the task of zero-shot-classification include:
        - `valhalla/distilbart-mnli-12-3`
        - `cross-encoder/nli-distilroberta-base`
        - `cross-encoder/nli-deberta-v3-xsmall`

    ### Usage example

    The following is a pipeline that classifies documents based on predefined classification labels
    retrieved from a search pipeline:

    ```python
    from haystack import Document
    from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
    from haystack.core.pipeline import Pipeline
    from haystack.document_stores.in_memory import InMemoryDocumentStore

    from haystack_integrations.components.classifiers.transformers import TransformersZeroShotDocumentClassifier

    documents = [Document(id="0", content="Today was a nice day!"),
                 Document(id="1", content="Yesterday was a bad day!")]

    document_store = InMemoryDocumentStore()
    retriever = InMemoryBM25Retriever(document_store=document_store)
    document_classifier = TransformersZeroShotDocumentClassifier(
        model="cross-encoder/nli-deberta-v3-xsmall",
        labels=["positive", "negative"],
    )

    document_store.write_documents(documents)

    pipeline = Pipeline()
    pipeline.add_component(instance=retriever, name="retriever")
    pipeline.add_component(instance=document_classifier, name="document_classifier")
    pipeline.connect("retriever", "document_classifier")

    queries = ["How was your day today?", "How was your day yesterday?"]
    expected_predictions = ["positive", "negative"]

    for idx, query in enumerate(queries):
        result = pipeline.run({"retriever": {"query": query, "top_k": 1}})
        assert result["document_classifier"]["documents"][0].to_dict()["id"] == str(idx)
        assert (result["document_classifier"]["documents"][0].to_dict()["classification"]["label"]
                == expected_predictions[idx])
    ```
    """

    def __init__(
        self,
        model: str,
        labels: list[str],
        multi_label: bool = False,
        classification_field: str | None = None,
        device: ComponentDevice | None = None,
        token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        huggingface_pipeline_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """
        Initializes the TransformersZeroShotDocumentClassifier.

        See the Hugging Face [website](https://huggingface.co/models?pipeline_tag=zero-shot-classification&sort=downloads&search=nli)
        for the full list of zero-shot classification (NLI) models.

        :param model:
            The name or path of a Hugging Face model for zero-shot document classification.
        :param labels:
            The set of possible class labels to classify each document into, for example,
            ["positive", "negative"]. The labels depend on the selected model.
        :param multi_label:
            Whether or not multiple candidate labels can be true.
            If `False`, the scores are normalized such that
            the sum of the label likelihoods for each sequence is 1. If `True`, the labels are considered
            independent and probabilities are normalized for each candidate by doing a softmax of the entailment
            score vs. the contradiction score.
        :param classification_field:
            Name of document's meta field to be used for classification.
            If not set, `Document.content` is used by default.
        :param device:
            The device on which the model is loaded. If `None`, the default device is automatically
            selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
        :param token:
            The Hugging Face token to use as HTTP bearer authorization.
            Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
        :param huggingface_pipeline_kwargs:
            Dictionary containing keyword arguments used to initialize the
            Hugging Face pipeline for zero-shot classification.
        """

        self.classification_field = classification_field

        self.token = token
        self.labels = labels
        self.multi_label = multi_label

        # Merge the explicit init arguments (model, device, token) into the pipeline kwargs and pin the task.
        huggingface_pipeline_kwargs = _resolve_hf_pipeline_kwargs(
            huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},
            model=model,
            task="zero-shot-classification",
            supported_tasks=["zero-shot-classification"],
            device=device,
            token=token,
        )

        self.huggingface_pipeline_kwargs = huggingface_pipeline_kwargs
        # The Hugging Face pipeline is created lazily in `warm_up`.
        self.pipeline: HfPipeline | None = None

    def _get_telemetry_data(self) -> dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.

        :returns:
            A dictionary with the model name, or a placeholder naming the model's type if it is not a string.
        """
        if isinstance(self.huggingface_pipeline_kwargs["model"], str):
            return {"model": self.huggingface_pipeline_kwargs["model"]}
        return {"model": f"[object of type {type(self.huggingface_pipeline_kwargs['model'])}]"}

    def warm_up(self) -> None:
        """
        Initializes the component by creating the Hugging Face pipeline, unless it already exists.
        """
        if self.pipeline is None:
            self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)

    @staticmethod
    def _copy_nested_dicts(kwargs: dict[str, Any]) -> dict[str, Any]:
        """
        Returns a copy of `kwargs` in which every nested dictionary is copied as well.

        Non-dictionary values are shared with the original, not duplicated.
        """
        return {
            key: (
                TransformersZeroShotDocumentClassifier._copy_nested_dicts(value) if isinstance(value, dict) else value
            )
            for key, value in kwargs.items()
        }

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        The resolved token is removed from the serialized pipeline kwargs. The component's own
        `huggingface_pipeline_kwargs` are left untouched.

        :returns:
            Dictionary with serialized data.
        """
        serialization_dict = default_to_dict(
            self,
            labels=self.labels,
            model=self.huggingface_pipeline_kwargs["model"],
            # Pass a copy: the token removal and `serialize_hf_model_kwargs` below edit the dictionary in place,
            # and `default_to_dict` may hand back the very object it was given. Without the copy, serializing
            # the component could strip the token from (and alter) the kwargs later used by `warm_up`.
            huggingface_pipeline_kwargs=self._copy_nested_dicts(self.huggingface_pipeline_kwargs),
            token=self.token,
            multi_label=self.multi_label,
            classification_field=self.classification_field,
        )

        huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
        # Never write the resolved token value into the serialized output.
        huggingface_pipeline_kwargs.pop("token", None)

        serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
        return serialization_dict

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "TransformersZeroShotDocumentClassifier":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        if data["init_parameters"].get("huggingface_pipeline_kwargs") is not None:
            deserialize_hf_model_kwargs(data["init_parameters"]["huggingface_pipeline_kwargs"])
        return default_from_dict(cls, data)

    @component.output_types(documents=list[Document])
    def run(self, documents: list[Document], batch_size: int = 1) -> dict[str, Any]:
        """
        Classifies the documents based on the provided labels and adds them to their metadata.

        The classification results are stored in the `classification` dict within
        each document's metadata: `label` and `score` hold the first label and score returned by the pipeline,
        and the scores for each label are available under the `details` key.
        The input documents are not modified; updated copies are returned.

        :param documents:
            Documents to process.
        :param batch_size:
            Batch size used for processing the content in each document.
        :returns:
            A dictionary with the following key:
            - `documents`: A list of documents with an added metadata field called `classification`.
        :raises TypeError:
            If `documents` is not a list, or its first element is not a `Document`.
        :raises ValueError:
            If `classification_field` is set and missing from the metadata of at least one document.
        """

        if self.pipeline is None:
            self.warm_up()

        # Only the first element is inspected; this mainly catches a plain string or list of strings passed by mistake.
        if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
            msg = (
                "TransformerZeroShotDocumentClassifier expects a list of documents as input. "
                "In case you want to classify and route a text, please use the TransformersZeroShotTextRouter."
            )
            raise TypeError(msg)

        # Nothing to classify: skip the model call entirely.
        if not documents:
            return {"documents": []}

        if self.classification_field is None:
            texts = [doc.content for doc in documents]
        else:
            invalid_doc_ids = [doc.id for doc in documents if self.classification_field not in doc.meta]
            if invalid_doc_ids:
                msg = (
                    f"The following documents do not have the classification field '{self.classification_field}': "
                    f"{', '.join(invalid_doc_ids)}"
                )
                raise ValueError(msg)
            texts = [doc.meta[self.classification_field] for doc in documents]

        # mypy doesn't know this is set in warm_up
        predictions = self.pipeline(  # type: ignore[misc]
            texts, self.labels, multi_label=self.multi_label, batch_size=batch_size
        )

        new_documents = []
        for prediction, document in zip(predictions, documents, strict=True):
            formatted_prediction = {
                "label": prediction["labels"][0],
                "score": prediction["scores"][0],
                "details": dict(zip(prediction["labels"], prediction["scores"], strict=True)),
            }
            # Build a new meta dict and a new Document so the caller's objects stay unchanged.
            new_meta = {**document.meta, "classification": formatted_prediction}
            new_documents.append(replace(document, meta=new_meta))

        return {"documents": new_documents}
File without changes
@@ -0,0 +1,3 @@
1
+ # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
@@ -0,0 +1,234 @@
1
+ # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ import asyncio
6
+ import copy
7
+ from typing import Any
8
+
9
+ import torch
10
+ from haystack import logging
11
+ from haystack.dataclasses import AsyncStreamingCallbackT, ComponentInfo, StreamingChunk, SyncStreamingCallbackT
12
+ from haystack.utils.auth import Secret
13
+ from haystack.utils.device import ComponentDevice
14
+ from huggingface_hub import model_info
15
+
16
+ from transformers import (
17
+ PreTrainedTokenizer,
18
+ PreTrainedTokenizerBase,
19
+ PreTrainedTokenizerFast,
20
+ StoppingCriteria,
21
+ TextStreamer,
22
+ )
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
def _resolve_hf_device_map(device: ComponentDevice | None, model_kwargs: dict[str, Any] | None) -> dict[str, Any]:
    """
    Return a copy of `model_kwargs` that is guaranteed to contain the keyword argument `device_map`.

    This is useful when you want to force a transformers model loaded with `AutoModel.from_pretrained` to
    use `device_map`.

    If both `device` and a `device_map` entry in `model_kwargs` are given, `device` is ignored and a warning
    is logged.

    :param device: The device on which the model is loaded. If `None`, the default device is automatically
        selected.
    :param model_kwargs: Additional HF keyword arguments passed to `AutoModel.from_pretrained`.
        For details on what kwargs you can pass, see the model's documentation.
    :returns: A shallow copy of `model_kwargs` (or a new dictionary if it was `None` or empty) with
        `device_map` set.
    """
    resolved_kwargs = copy.copy(model_kwargs) or {}

    if not resolved_kwargs.get("device_map"):
        # No usable device_map was supplied: derive one from `device`.
        # device_map allows quantized loading and multi-device inference; it requires accelerate,
        # which is always installed when using `pip install transformers[torch]`.
        resolved_kwargs["device_map"] = ComponentDevice.resolve_device(device).to_hf()
        return resolved_kwargs

    # A device_map supplied through model_kwargs takes precedence over `device`.
    if device is not None:
        logger.warning(
            "The parameters `device` and `device_map` from `model_kwargs` are both provided. "
            "Ignoring `device` and using `device_map`."
        )
    return resolved_kwargs
59
+
60
+
61
def _resolve_hf_pipeline_kwargs(
    huggingface_pipeline_kwargs: dict[str, Any],
    model: str,
    task: str | None,
    supported_tasks: list[str],
    device: ComponentDevice | None,
    token: Secret | None,
) -> dict[str, Any]:
    """
    Resolve the HuggingFace pipeline keyword arguments based on explicit user inputs.

    The passed dictionary is not modified; the resolved values are returned in a new dictionary.

    :param huggingface_pipeline_kwargs: Dictionary containing keyword arguments used to initialize a
        Hugging Face pipeline.
    :param model: The name or path of a Hugging Face model on the HuggingFace Hub.
        Ignored if `huggingface_pipeline_kwargs` already contains a `model` entry.
    :param task: The task for the Hugging Face pipeline. If `None`, the `task` entry of
        `huggingface_pipeline_kwargs` is used; if that is missing too and the model is given as a string,
        the task is looked up from the model's `pipeline_tag` on the Hub.
    :param supported_tasks: The list of supported tasks to check the task of the model against. If the task of the model
        is not present within this list then a ValueError is thrown.
    :param device: The device on which the model is loaded. If `None`, the default device is automatically
        selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
    :param token: The token to use as HTTP bearer authorization for remote files.
        If the token is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
    :returns: The resolved pipeline keyword arguments, including `model`, `token` and `task`.
    :raises ValueError: If the resolved task is not one of `supported_tasks`.
    """
    resolved_token = token.resolve_value() if token else None

    # Work on a shallow copy so the caller's dictionary is left untouched; in particular the
    # resolved token value is not written back into a dictionary the caller still owns.
    resolved_kwargs = dict(huggingface_pipeline_kwargs)

    # check if the huggingface_pipeline_kwargs contain the essential parameters
    # otherwise, populate them with values from other init parameters
    resolved_kwargs.setdefault("model", model)
    resolved_kwargs.setdefault("token", resolved_token)

    resolved_device = ComponentDevice.resolve_device(device)
    # overwrite=False: a device/device map already present in the kwargs wins over `device`
    resolved_device.update_hf_kwargs(resolved_kwargs, overwrite=False)

    # task identification and validation
    task = task or resolved_kwargs.get("task")
    if task is None and isinstance(resolved_kwargs["model"], str):
        task = model_info(resolved_kwargs["model"], token=resolved_kwargs["token"]).pipeline_tag

    if task not in supported_tasks:
        msg = f"Task '{task}' is not supported. The supported tasks are: {', '.join(supported_tasks)}."
        raise ValueError(msg)
    resolved_kwargs["task"] = task
    return resolved_kwargs
102
+
103
+
104
class _StopWordsCriteria(StoppingCriteria):
    """
    Stops text generation in HuggingFace generators if any one of the stop words is generated.

    Note: When a stop word is encountered, the generation of new text is stopped.
    However, if the stop word is in the prompt itself, it can stop generating new text
    prematurely after the first token. This is particularly important for LLMs designed
    for dialogue generation. For these models, like for example mosaicml/mpt-7b-chat,
    the output includes both the new text and the original prompt. Therefore, it's important
    to make sure your prompt has no stop words.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
        stop_words: list[str],
        device: str | torch.device = "cpu",
    ) -> None:
        """
        Creates an instance of _StopWordsCriteria.

        :param tokenizer: Tokenizer used to encode the stop words. If it has no pad token, one is set on it
            (the EOS token if available, otherwise a new `[PAD]` special token).
        :param stop_words: The stop words that end the generation.
        :param device: The device the encoded stop words are moved to.
        :raises TypeError: If `tokenizer` is not a `PreTrainedTokenizer` or `PreTrainedTokenizerFast`.
        """
        super().__init__()
        # check if tokenizer is a valid tokenizer
        if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
            msg = (
                f"Invalid tokenizer provided for _StopWordsCriteria - {tokenizer}. "
                f"Please provide a valid tokenizer from the HuggingFace Transformers library."
            )
            raise TypeError(msg)
        # Batch-encoding stop words of different lengths needs a pad token.
        if not tokenizer.pad_token:
            if tokenizer.eos_token:
                tokenizer.pad_token = tokenizer.eos_token
            else:
                tokenizer.add_special_tokens({"pad_token": "[PAD]"})
        encoded_stop_words = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt")
        # Padded ids, one row per stop word.
        self.stop_ids = encoded_stop_words.input_ids.to(device)

        # Rows of `stop_ids` are padded to the longest stop word. Comparing the padded row against the
        # generated tokens would make every shorter stop word effectively unmatchable, so keep an
        # unpadded copy of each stop word for matching.
        attention_mask = getattr(encoded_stop_words, "attention_mask", None)
        if isinstance(attention_mask, torch.Tensor):
            keep = attention_mask.to(device).bool()
            self._stop_sequences = [ids[mask] for ids, mask in zip(self.stop_ids, keep, strict=True)]
        else:
            # No attention mask available: fall back to the padded rows.
            self._stop_sequences = list(self.stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs: Any) -> bool:  # noqa: ARG002
        """Check if any of the stop words are generated in the current text generation step."""
        return any(self.is_stop_word_found(input_ids, stop_sequence) for stop_sequence in self._stop_sequences)

    @staticmethod
    def is_stop_word_found(generated_text_ids: torch.Tensor, stop_id: torch.Tensor) -> bool:
        """
        Performs phrase matching.

        Checks if the generated text ends with the given sequence of stop tokens. Only the last sequence
        of `generated_text_ids` is inspected.

        :param generated_text_ids: Token ids generated so far; assumed to have shape
            (batch, sequence_length) - TODO confirm against the callers.
        :param stop_id: 1-D tensor with the token ids of one stop word.
        :returns: `True` if the last generated sequence ends with `stop_id`, `False` otherwise. Also `False`
            if `stop_id` is empty or longer than the generated sequence.
        """
        last_sequence = generated_text_ids[-1]
        generated_length = last_sequence.size(0)
        stop_length = stop_id.size(0)
        # An empty stop sequence would match everything, and one longer than the generated text cannot
        # match (the element-wise comparison would fail on mismatched sizes).
        if stop_length == 0 or stop_length > generated_length:
            return False
        tail = last_sequence[generated_length - stop_length :]
        return bool(tail.eq(stop_id).all())
158
+
159
+
160
class _HFTokenStreamingHandler(TextStreamer):
    """
    Streaming handler for TransformersChatGenerator.

    Note: This is a helper class for TransformersChatGenerator that forwards generated text
    to a Haystack SyncStreamingCallbackT callback as `StreamingChunk` objects.

    Do not use this class directly.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        stream_handler: SyncStreamingCallbackT,
        stop_words: list[str] | None = None,
        component_info: ComponentInfo | None = None,
    ) -> None:
        """
        Creates an instance of _HFTokenStreamingHandler.

        :param tokenizer: Tokenizer handed to the underlying `TextStreamer`.
        :param stream_handler: Callback invoked with every streamed chunk.
        :param stop_words: Texts that are not forwarded to the callback.
        :param component_info: Component information attached to every chunk.
        """
        super().__init__(tokenizer=tokenizer, skip_prompt=True)
        # Number of `on_finalized_text` calls so far; used to flag the first chunk.
        self._call_counter = 0
        self.component_info = component_info
        self.stop_words = stop_words if stop_words else []
        self.token_handler = stream_handler

    def on_finalized_text(self, word: str, stream_end: bool = False) -> None:
        """
        Callback function for handling the generated text.

        :param word: The finalized piece of text.
        :param stream_end: Whether this is the last piece of the stream; a newline is appended if so.
        """
        self._call_counter += 1

        if stream_end:
            text = word + "\n"
        else:
            text = word

        # Stop words are swallowed instead of being streamed.
        if text.strip() in self.stop_words:
            return

        chunk = StreamingChunk(
            content=text, index=0, start=self._call_counter == 1, component_info=self.component_info
        )
        self.token_handler(chunk)
194
+
195
+
196
class _AsyncHFTokenStreamingHandler(TextStreamer):
    """
    Async streaming handler for TransformersChatGenerator.

    Note: This is a helper class for TransformersChatGenerator that buffers generated text in a queue
    and forwards it to a Haystack Callable[StreamingChunk, Awaitable[None]] callback.

    Do not use this class directly.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        stream_handler: AsyncStreamingCallbackT,
        stop_words: list[str] | None = None,
        component_info: ComponentInfo | None = None,
    ) -> None:
        """
        Creates an instance of _AsyncHFTokenStreamingHandler.

        :param tokenizer: Tokenizer handed to the underlying `TextStreamer`.
        :param stream_handler: Async callback awaited with every streamed chunk.
        :param stop_words: Texts that are not forwarded to the callback.
        :param component_info: Component information attached to every chunk.
        """
        super().__init__(tokenizer=tokenizer, skip_prompt=True)
        # Chunks wait here until `process_queue` hands them to the async callback.
        self._queue: asyncio.Queue[StreamingChunk] = asyncio.Queue()
        self.component_info = component_info
        self.stop_words = stop_words if stop_words else []
        self.token_handler = stream_handler

    def on_finalized_text(self, word: str, stream_end: bool = False) -> None:
        """
        Synchronous callback that puts chunks in a queue.

        :param word: The finalized piece of text.
        :param stream_end: Whether this is the last piece of the stream; a newline is appended if so.
        """
        if stream_end:
            text = word + "\n"
        else:
            text = word

        # Stop words are swallowed instead of being queued.
        if text.strip() in self.stop_words:
            return

        self._queue.put_nowait(StreamingChunk(content=text, component_info=self.component_info))

    async def process_queue(self) -> None:
        """
        Process the queue of streaming chunks.

        Runs until the surrounding task is cancelled; the cancellation is absorbed and the coroutine
        returns normally.
        """
        try:
            while True:
                chunk = await self._queue.get()
                await self.token_handler(chunk)
                self._queue.task_done()
        except asyncio.CancelledError:
            return
File without changes
@@ -0,0 +1,6 @@
1
+ # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ from .named_entity_extractor import NamedEntityAnnotation, TransformersNamedEntityExtractor
5
+
6
+ __all__ = ["NamedEntityAnnotation", "TransformersNamedEntityExtractor"]