transformers-haystack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- haystack_integrations/components/classifiers/py.typed +0 -0
- haystack_integrations/components/classifiers/transformers/__init__.py +6 -0
- haystack_integrations/components/classifiers/transformers/zero_shot_document_classifier.py +247 -0
- haystack_integrations/components/common/py.typed +0 -0
- haystack_integrations/components/common/transformers/__init__.py +3 -0
- haystack_integrations/components/common/transformers/utils.py +234 -0
- haystack_integrations/components/extractors/py.typed +0 -0
- haystack_integrations/components/extractors/transformers/__init__.py +6 -0
- haystack_integrations/components/extractors/transformers/named_entity_extractor.py +262 -0
- haystack_integrations/components/generators/py.typed +0 -0
- haystack_integrations/components/generators/transformers/__init__.py +6 -0
- haystack_integrations/components/generators/transformers/chat/__init__.py +3 -0
- haystack_integrations/components/generators/transformers/chat/chat_generator.py +666 -0
- haystack_integrations/components/readers/py.typed +0 -0
- haystack_integrations/components/readers/transformers/__init__.py +6 -0
- haystack_integrations/components/readers/transformers/extractive_reader.py +662 -0
- haystack_integrations/components/routers/py.typed +0 -0
- haystack_integrations/components/routers/transformers/__init__.py +7 -0
- haystack_integrations/components/routers/transformers/text_router.py +196 -0
- haystack_integrations/components/routers/transformers/zero_shot_text_router.py +205 -0
- transformers_haystack-0.1.0.dist-info/METADATA +38 -0
- transformers_haystack-0.1.0.dist-info/RECORD +24 -0
- transformers_haystack-0.1.0.dist-info/WHEEL +4 -0
- transformers_haystack-0.1.0.dist-info/licenses/LICENSE.txt +201 -0
|
File without changes
|
|
@@ -0,0 +1,247 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
2
|
+
#
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
from dataclasses import replace
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from haystack import Document, component, default_from_dict, default_to_dict
|
|
9
|
+
from haystack.utils import ComponentDevice, Secret
|
|
10
|
+
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs
|
|
11
|
+
|
|
12
|
+
from haystack_integrations.components.common.transformers.utils import _resolve_hf_pipeline_kwargs
|
|
13
|
+
from transformers import Pipeline as HfPipeline
|
|
14
|
+
from transformers import pipeline
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@component
|
|
18
|
+
class TransformersZeroShotDocumentClassifier:
|
|
19
|
+
"""
|
|
20
|
+
Performs zero-shot classification of documents based on given labels and adds the predicted label to their metadata.
|
|
21
|
+
|
|
22
|
+
The component uses a Hugging Face pipeline for zero-shot classification.
|
|
23
|
+
Provide the model and the set of labels to be used for categorization during initialization.
|
|
24
|
+
Additionally, you can configure the component to allow multiple labels to be true.
|
|
25
|
+
|
|
26
|
+
Classification is run on the document's content field by default. If you want it to run on another field, set the
|
|
27
|
+
`classification_field` to one of the document's metadata fields.
|
|
28
|
+
|
|
29
|
+
Available models for the task of zero-shot-classification include:
|
|
30
|
+
- `valhalla/distilbart-mnli-12-3`
|
|
31
|
+
- `cross-encoder/nli-distilroberta-base`
|
|
32
|
+
- `cross-encoder/nli-deberta-v3-xsmall`
|
|
33
|
+
|
|
34
|
+
### Usage example
|
|
35
|
+
|
|
36
|
+
The following is a pipeline that classifies documents based on predefined classification labels
|
|
37
|
+
retrieved from a search pipeline:
|
|
38
|
+
|
|
39
|
+
```python
|
|
40
|
+
from haystack import Document
|
|
41
|
+
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever
|
|
42
|
+
from haystack.core.pipeline import Pipeline
|
|
43
|
+
from haystack.document_stores.in_memory import InMemoryDocumentStore
|
|
44
|
+
|
|
45
|
+
from haystack_integrations.components.classifiers.transformers import TransformersZeroShotDocumentClassifier
|
|
46
|
+
|
|
47
|
+
documents = [Document(id="0", content="Today was a nice day!"),
|
|
48
|
+
Document(id="1", content="Yesterday was a bad day!")]
|
|
49
|
+
|
|
50
|
+
document_store = InMemoryDocumentStore()
|
|
51
|
+
retriever = InMemoryBM25Retriever(document_store=document_store)
|
|
52
|
+
document_classifier = TransformersZeroShotDocumentClassifier(
|
|
53
|
+
model="cross-encoder/nli-deberta-v3-xsmall",
|
|
54
|
+
labels=["positive", "negative"],
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
document_store.write_documents(documents)
|
|
58
|
+
|
|
59
|
+
pipeline = Pipeline()
|
|
60
|
+
pipeline.add_component(instance=retriever, name="retriever")
|
|
61
|
+
pipeline.add_component(instance=document_classifier, name="document_classifier")
|
|
62
|
+
pipeline.connect("retriever", "document_classifier")
|
|
63
|
+
|
|
64
|
+
queries = ["How was your day today?", "How was your day yesterday?"]
|
|
65
|
+
expected_predictions = ["positive", "negative"]
|
|
66
|
+
|
|
67
|
+
for idx, query in enumerate(queries):
|
|
68
|
+
result = pipeline.run({"retriever": {"query": query, "top_k": 1}})
|
|
69
|
+
assert result["document_classifier"]["documents"][0].to_dict()["id"] == str(idx)
|
|
70
|
+
assert (result["document_classifier"]["documents"][0].to_dict()["classification"]["label"]
|
|
71
|
+
== expected_predictions[idx])
|
|
72
|
+
```
|
|
73
|
+
"""
|
|
74
|
+
|
|
75
|
+
def __init__(
    self,
    model: str,
    labels: list[str],
    multi_label: bool = False,
    classification_field: str | None = None,
    device: ComponentDevice | None = None,
    token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
    huggingface_pipeline_kwargs: dict[str, Any] | None = None,
) -> None:
    """
    Creates a TransformersZeroShotDocumentClassifier.

    Browse the Hugging Face [model hub](https://huggingface.co/models?pipeline_tag=zero-shot-classification&sort=downloads&search=nli)
    for the full list of zero-shot classification (NLI) models.

    :param model:
        The name or path of a Hugging Face model for zero shot document classification.
    :param labels:
        The set of possible class labels to classify each document into, for example,
        ["positive", "negative"]. The labels depend on the selected model.
    :param multi_label:
        Whether or not multiple candidate labels can be true.
        If `False`, the scores are normalized such that
        the sum of the label likelihoods for each sequence is 1. If `True`, the labels are considered
        independent and probabilities are normalized for each candidate by doing a softmax of the entailment
        score vs. the contradiction score.
    :param classification_field:
        Name of document's meta field to be used for classification.
        If not set, `Document.content` is used by default.
    :param device:
        The device on which the model is loaded. If `None`, the default device is automatically
        selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
    :param token:
        The Hugging Face token to use as HTTP bearer authorization.
        Check your HF token in your [account settings](https://huggingface.co/settings/tokens).
    :param huggingface_pipeline_kwargs:
        Dictionary containing keyword arguments used to initialize the
        Hugging Face pipeline for text classification.
    """
    self.labels = labels
    self.multi_label = multi_label
    self.classification_field = classification_field
    self.token = token

    # Merge the explicit init parameters into the pipeline kwargs; this component only
    # accepts the zero-shot classification task.
    zero_shot_task = "zero-shot-classification"
    self.huggingface_pipeline_kwargs = _resolve_hf_pipeline_kwargs(
        huggingface_pipeline_kwargs=huggingface_pipeline_kwargs or {},
        model=model,
        task=zero_shot_task,
        supported_tasks=[zero_shot_task],
        device=device,
        token=token,
    )

    # The pipeline is created lazily in `warm_up`, not at construction time.
    self.pipeline: HfPipeline | None = None
|
|
133
|
+
|
|
134
|
+
def _get_telemetry_data(self) -> dict[str, Any]:
    """
    Data that is sent to Posthog for usage analytics.

    :returns:
        A dictionary with a single `model` key: the model name when it is a string,
        otherwise a placeholder naming the type of the model object.
    """
    configured_model = self.huggingface_pipeline_kwargs["model"]
    if not isinstance(configured_model, str):
        # Report only the type of non-string model objects, never the object itself.
        return {"model": f"[object of type {type(configured_model)}]"}
    return {"model": configured_model}
|
|
141
|
+
|
|
142
|
+
def warm_up(self) -> None:
    """
    Initializes the component.

    Builds the Hugging Face pipeline from `huggingface_pipeline_kwargs` on the first call;
    later calls are no-ops because the existing pipeline is kept.
    """
    if self.pipeline is not None:
        return
    self.pipeline = pipeline(**self.huggingface_pipeline_kwargs)
|
|
148
|
+
|
|
149
|
+
def to_dict(self) -> dict[str, Any]:
    """
    Serializes the component to a dictionary.

    The pipeline kwargs are copied before they are cleaned up for serialization, so calling
    this method leaves `self.huggingface_pipeline_kwargs` (including its `token` entry) untouched
    and the component can still be warmed up afterwards.

    :returns:
        Dictionary with serialized data.
    """

    def _copy_nested_dicts(value: Any) -> Any:
        # Copy dictionaries recursively but share every other object, so large values
        # are not duplicated just to serialize the component.
        if isinstance(value, dict):
            return {key: _copy_nested_dicts(item) for key, item in value.items()}
        return value

    serialization_dict = default_to_dict(
        self,
        labels=self.labels,
        model=self.huggingface_pipeline_kwargs["model"],
        huggingface_pipeline_kwargs=_copy_nested_dicts(self.huggingface_pipeline_kwargs),
        token=self.token,
        multi_label=self.multi_label,
        classification_field=self.classification_field,
    )

    huggingface_pipeline_kwargs = serialization_dict["init_parameters"]["huggingface_pipeline_kwargs"]
    # The kwargs hold the resolved token value; it must never end up in the serialized output.
    huggingface_pipeline_kwargs.pop("token", None)

    # Called for its in-place effect on the (copied) kwargs; the return value is not used.
    serialize_hf_model_kwargs(huggingface_pipeline_kwargs)
    return serialization_dict
|
|
171
|
+
|
|
172
|
+
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "TransformersZeroShotDocumentClassifier":
    """
    Deserializes the component from a dictionary.

    :param data:
        Dictionary to deserialize from.
    :returns:
        Deserialized component.
    """
    pipeline_kwargs = data["init_parameters"].get("huggingface_pipeline_kwargs")
    if pipeline_kwargs is not None:
        # Restores the serialized model kwargs inside `data` before the component is rebuilt.
        deserialize_hf_model_kwargs(pipeline_kwargs)
    return default_from_dict(cls, data)
|
|
185
|
+
|
|
186
|
+
@component.output_types(documents=list[Document])
def run(self, documents: list[Document], batch_size: int = 1) -> dict[str, Any]:
    """
    Classifies the documents based on the provided labels and adds them to their metadata.

    The classification results are stored in the `classification` dict within
    each document's metadata. The top label and its score are available under the `label` and
    `score` keys, and the scores for every label are available under the `details` key.

    The input documents are not modified; new documents with the extended metadata are returned.

    :param documents:
        Documents to process. An empty list is returned as an empty list without running the model.
    :param batch_size:
        Batch size used for processing the content in each document.
    :returns:
        A dictionary with the following key:
        - `documents`: A list of documents with an added metadata field called `classification`.
    :raises TypeError:
        If `documents` is not a list or its first element is not a `Document`.
    :raises ValueError:
        If `classification_field` is set and missing from the metadata of at least one document.
    """

    if self.pipeline is None:
        self.warm_up()

    if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
        msg = (
            "TransformersZeroShotDocumentClassifier expects a list of documents as input. "
            "In case you want to classify and route a text, please use the TransformersZeroShotTextRouter."
        )
        raise TypeError(msg)

    # The Hugging Face zero-shot pipeline rejects an empty list of sequences, so an empty
    # input (for example from a retriever that found nothing) is answered directly.
    if not documents:
        return {"documents": []}

    field = self.classification_field
    if field is not None:
        invalid_doc_ids = [doc.id for doc in documents if field not in doc.meta]
        if invalid_doc_ids:
            msg = (
                f"The following documents do not have the classification field '{self.classification_field}': "
                f"{', '.join(invalid_doc_ids)}"
            )
            raise ValueError(msg)

    texts = [(doc.content if field is None else doc.meta[field]) for doc in documents]

    # mypy doesn't know this is set in warm_up
    predictions = self.pipeline(  # type: ignore[misc]
        texts, self.labels, multi_label=self.multi_label, batch_size=batch_size
    )

    new_documents = []
    for prediction, document in zip(predictions, documents, strict=True):
        formatted_prediction = {
            # The first label/score pair of each prediction is reported as the top result.
            "label": prediction["labels"][0],
            "score": prediction["scores"][0],
            "details": dict(zip(prediction["labels"], prediction["scores"], strict=True)),
        }
        new_meta = {**document.meta, "classification": formatted_prediction}
        new_documents.append(replace(document, meta=new_meta))

    return {"documents": new_documents}
|
|
File without changes
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
2
|
+
#
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import copy
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
from haystack import logging
|
|
11
|
+
from haystack.dataclasses import AsyncStreamingCallbackT, ComponentInfo, StreamingChunk, SyncStreamingCallbackT
|
|
12
|
+
from haystack.utils.auth import Secret
|
|
13
|
+
from haystack.utils.device import ComponentDevice
|
|
14
|
+
from huggingface_hub import model_info
|
|
15
|
+
|
|
16
|
+
from transformers import (
|
|
17
|
+
PreTrainedTokenizer,
|
|
18
|
+
PreTrainedTokenizerBase,
|
|
19
|
+
PreTrainedTokenizerFast,
|
|
20
|
+
StoppingCriteria,
|
|
21
|
+
TextStreamer,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def _resolve_hf_device_map(device: ComponentDevice | None, model_kwargs: dict[str, Any] | None) -> dict[str, Any]:
    """
    Return a copy of `model_kwargs` that is guaranteed to contain the keyword argument `device_map`.

    This is useful when you want to force a transformers model loaded with `AutoModel.from_pretrained`
    to use `device_map`.

    If both `device` and a `device_map` inside `model_kwargs` are given, the `device` parameter is
    ignored and a warning is logged.

    :param device: The device on which the model is loaded. If `None`, the default device is automatically
        selected.
    :param model_kwargs: Additional HF keyword arguments passed to `AutoModel.from_pretrained`.
        For details on what kwargs you can pass, see the model's documentation.
    :returns: A shallow copy of `model_kwargs` (or a new dict) with `device_map` set.
    """
    resolved_kwargs = copy.copy(model_kwargs) or {}
    user_device_map = resolved_kwargs.get("device_map")

    if not user_device_map:
        # No usable device_map was supplied: derive it from `device`.
        # device_map allows quantized loading and multi device inference; it requires accelerate,
        # which is always installed when using `pip install transformers[torch]`.
        resolved_kwargs["device_map"] = ComponentDevice.resolve_device(device).to_hf()
        return resolved_kwargs

    if device is not None:
        logger.warning(
            "The parameters `device` and `device_map` from `model_kwargs` are both provided. "
            "Ignoring `device` and using `device_map`."
        )
    # The user-provided device_map wins and is kept as-is.
    resolved_kwargs["device_map"] = user_device_map
    return resolved_kwargs
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _resolve_hf_pipeline_kwargs(
    huggingface_pipeline_kwargs: dict[str, Any],
    model: str,
    task: str | None,
    supported_tasks: list[str],
    device: ComponentDevice | None,
    token: Secret | None,
) -> dict[str, Any]:
    """
    Resolve the HuggingFace pipeline keyword arguments based on explicit user inputs.

    The passed dictionary is not modified: the resolved values are written to a shallow copy,
    so a kwargs dictionary that is reused for several components does not carry the model, token
    or task of one component over to the next.

    :param huggingface_pipeline_kwargs: Dictionary containing keyword arguments used to initialize a
        Hugging Face pipeline.
    :param model: The name or path of a Hugging Face model on the HuggingFace Hub.
    :param task: The task for the Hugging Face pipeline.
    :param supported_tasks: The list of supported tasks to check the task of the model against. If the task of the model
        is not present within this list then a ValueError is thrown.
    :param device: The device on which the model is loaded. If `None`, the default device is automatically
        selected. If a device/device map is specified in `huggingface_pipeline_kwargs`, it overrides this parameter.
    :param token: The token to use as HTTP bearer authorization for remote files.
        If the token is also specified in the `huggingface_pipeline_kwargs`, this parameter will be ignored.
    :returns: A new dictionary with `model`, `token`, the device settings and `task` resolved.
    :raises ValueError: If the resolved task is not one of `supported_tasks`.
    """
    resolved_token = token.resolve_value() if token else None

    # Shallow copy: top-level keys added below stay out of the caller's dictionary.
    resolved_kwargs = copy.copy(huggingface_pipeline_kwargs)

    # check if the huggingface_pipeline_kwargs contain the essential parameters
    # otherwise, populate them with values from other init parameters
    resolved_kwargs.setdefault("model", model)
    resolved_kwargs.setdefault("token", resolved_token)

    resolved_device = ComponentDevice.resolve_device(device)
    # overwrite=False keeps device settings that are already present in the kwargs.
    resolved_device.update_hf_kwargs(resolved_kwargs, overwrite=False)

    # task identification and validation
    task = task or resolved_kwargs.get("task")
    if task is None and isinstance(resolved_kwargs["model"], str):
        # No task given anywhere: fall back to the pipeline tag reported for the model.
        task = model_info(resolved_kwargs["model"], token=resolved_kwargs["token"]).pipeline_tag

    if task not in supported_tasks:
        msg = f"Task '{task}' is not supported. The supported tasks are: {', '.join(supported_tasks)}."
        raise ValueError(msg)
    resolved_kwargs["task"] = task
    return resolved_kwargs
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class _StopWordsCriteria(StoppingCriteria):
|
|
105
|
+
"""
|
|
106
|
+
Stops text generation in HuggingFace generators if any one of the stop words is generated.
|
|
107
|
+
|
|
108
|
+
Note: When a stop word is encountered, the generation of new text is stopped.
|
|
109
|
+
However, if the stop word is in the prompt itself, it can stop generating new text
|
|
110
|
+
prematurely after the first token. This is particularly important for LLMs designed
|
|
111
|
+
for dialogue generation. For these models, like for example mosaicml/mpt-7b-chat,
|
|
112
|
+
the output includes both the new text and the original prompt. Therefore, it's important
|
|
113
|
+
to make sure your prompt has no stop words.
|
|
114
|
+
"""
|
|
115
|
+
|
|
116
|
+
def __init__(
    self,
    tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
    stop_words: list[str],
    device: str | torch.device = "cpu",
) -> None:
    """
    Creates an instance of _StopWordsCriteria.

    The stop words are tokenized once, without special tokens, and their ids are stored in
    `self.stop_ids` on `device`. If the tokenizer has no pad token, one is set on the tokenizer
    itself (the EOS token when available, otherwise a new "[PAD]" special token).

    :param tokenizer: The Hugging Face tokenizer used to encode the stop words.
    :param stop_words: The stop words to look for during generation.
    :param device: The device the encoded stop word ids are moved to.
    :raises TypeError: If `tokenizer` is not a `PreTrainedTokenizer` or `PreTrainedTokenizerFast`.
    """
    super().__init__()

    # check if tokenizer is a valid tokenizer
    supported_tokenizers = (PreTrainedTokenizer, PreTrainedTokenizerFast)
    if not isinstance(tokenizer, supported_tokenizers):
        raise TypeError(
            f"Invalid tokenizer provided for _StopWordsCriteria - {tokenizer}. "
            "Please provide a valid tokenizer from the HuggingFace Transformers library."
        )

    # The padded encoding below needs a pad token, so make sure the tokenizer has one.
    if not tokenizer.pad_token:
        eos_token = tokenizer.eos_token
        if eos_token:
            tokenizer.pad_token = eos_token
        else:
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    # NOTE(review): with padding=True, stop words of different token lengths are padded to a
    # common length and the padded rows are what gets compared later - confirm that shorter
    # stop words still match as intended.
    stop_word_batch = tokenizer(stop_words, add_special_tokens=False, padding=True, return_tensors="pt")
    self.stop_ids = stop_word_batch.input_ids.to(device)
|
|
138
|
+
|
|
139
|
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs: Any) -> bool:  # noqa: ARG002
    """
    Check if any of the stop words are generated in the current text generation step.

    Each row of `self.stop_ids` is tested in order against `input_ids`; the check stops at the
    first stop word that is found. `scores` and `kwargs` are accepted but not used.
    """
    return any(self.is_stop_word_found(input_ids, stop_sequence) for stop_sequence in self.stop_ids)
|
|
146
|
+
|
|
147
|
+
@staticmethod
|
|
148
|
+
def is_stop_word_found(generated_text_ids: torch.Tensor, stop_id: torch.Tensor) -> bool:
|
|
149
|
+
"""
|
|
150
|
+
Performs phrase matching.
|
|
151
|
+
|
|
152
|
+
Checks if a sequence of stop tokens appears in a continuous or sequential order within the generated text.
|
|
153
|
+
"""
|
|
154
|
+
generated_text_ids = generated_text_ids[-1]
|
|
155
|
+
len_generated_text_ids = generated_text_ids.size(0)
|
|
156
|
+
len_stop_id = stop_id.size(0)
|
|
157
|
+
return all(generated_text_ids[len_generated_text_ids - len_stop_id :].eq(stop_id))
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
class _HFTokenStreamingHandler(TextStreamer):
|
|
161
|
+
"""
|
|
162
|
+
Streaming handler for TransformersChatGenerator.
|
|
163
|
+
|
|
164
|
+
Note: This is a helper class for TransformersChatGenerator enabling streaming
|
|
165
|
+
of generated text via Haystack SyncStreamingCallbackT callbacks.
|
|
166
|
+
|
|
167
|
+
Do not use this class directly.
|
|
168
|
+
"""
|
|
169
|
+
|
|
170
|
+
def __init__(
    self,
    tokenizer: PreTrainedTokenizerBase,
    stream_handler: SyncStreamingCallbackT,
    stop_words: list[str] | None = None,
    component_info: ComponentInfo | None = None,
) -> None:
    """
    Creates an instance of _HFTokenStreamingHandler.

    :param tokenizer: Tokenizer handed to the underlying `TextStreamer`, which is set up with `skip_prompt=True`.
    :param stream_handler: Callback that receives every streaming chunk.
    :param stop_words: Words that are not forwarded to the callback. Defaults to no stop words.
    :param component_info: Component information attached to every streaming chunk.
    """
    super().__init__(tokenizer=tokenizer, skip_prompt=True)
    # Number of times `on_finalized_text` has been called; used to flag the first chunk.
    self._call_counter = 0
    self.component_info = component_info
    self.stop_words = stop_words or []
    self.token_handler = stream_handler
|
|
183
|
+
|
|
184
|
+
def on_finalized_text(self, word: str, stream_end: bool = False) -> None:
    """
    Callback function for handling the generated text.

    Every call is counted. The text is forwarded to the stream handler as a `StreamingChunk`
    unless, stripped of surrounding whitespace, it is one of the stop words.

    :param word: The finalized piece of text.
    :param stream_end: Whether this is the end of the stream; a newline is appended to the text if so.
    """
    self._call_counter += 1

    word_to_send = word
    if stream_end:
        word_to_send = word + "\n"

    if word_to_send.strip() in self.stop_words:
        # Stop words are swallowed and never reach the handler.
        return

    self.token_handler(
        StreamingChunk(
            content=word_to_send,
            index=0,
            # Only the very first call is marked as the start of the stream.
            start=self._call_counter == 1,
            component_info=self.component_info,
        )
    )
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class _AsyncHFTokenStreamingHandler(TextStreamer):
|
|
197
|
+
"""
|
|
198
|
+
Async streaming handler for TransformersChatGenerator.
|
|
199
|
+
|
|
200
|
+
Note: This is a helper class for TransformersChatGenerator enabling
|
|
201
|
+
async streaming of generated text via Haystack Callable[StreamingChunk, Awaitable[None]] callbacks.
|
|
202
|
+
|
|
203
|
+
Do not use this class directly.
|
|
204
|
+
"""
|
|
205
|
+
|
|
206
|
+
def __init__(
    self,
    tokenizer: PreTrainedTokenizerBase,
    stream_handler: AsyncStreamingCallbackT,
    stop_words: list[str] | None = None,
    component_info: ComponentInfo | None = None,
) -> None:
    """
    Creates an instance of _AsyncHFTokenStreamingHandler.

    :param tokenizer: Tokenizer handed to the underlying `TextStreamer`, which is set up with `skip_prompt=True`.
    :param stream_handler: Async callback that is awaited for every streaming chunk.
    :param stop_words: Words that are not queued for the callback. Defaults to no stop words.
    :param component_info: Component information attached to every streaming chunk.
    """
    super().__init__(tokenizer=tokenizer, skip_prompt=True)
    # Chunks are buffered here by `on_finalized_text` and drained by `process_queue`.
    self._queue: asyncio.Queue[StreamingChunk] = asyncio.Queue()
    self.component_info = component_info
    self.stop_words = stop_words or []
    self.token_handler = stream_handler
|
|
219
|
+
|
|
220
|
+
def on_finalized_text(self, word: str, stream_end: bool = False) -> None:
    """
    Synchronous callback that puts chunks in a queue.

    The text is wrapped in a `StreamingChunk` and queued without blocking unless, stripped of
    surrounding whitespace, it is one of the stop words.

    :param word: The finalized piece of text.
    :param stream_end: Whether this is the end of the stream; a newline is appended to the text if so.
    """
    word_to_send = word
    if stream_end:
        word_to_send = word + "\n"

    if word_to_send.strip() in self.stop_words:
        # Stop words are dropped and never queued.
        return

    self._queue.put_nowait(StreamingChunk(content=word_to_send, component_info=self.component_info))
|
|
225
|
+
|
|
226
|
+
async def process_queue(self) -> None:
    """
    Process the queue of streaming chunks.

    Waits for chunks, awaits the stream handler for each of them and marks the chunk as done.
    The loop only ends when the coroutine is cancelled; the cancellation is absorbed and the
    method returns normally.
    """
    while True:
        try:
            next_chunk = await self._queue.get()
            await self.token_handler(next_chunk)
            self._queue.task_done()
        except asyncio.CancelledError:
            # Cancellation is the shutdown signal for this consumer loop.
            return
|
|
File without changes
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
2
|
+
#
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
from .named_entity_extractor import NamedEntityAnnotation, TransformersNamedEntityExtractor
|
|
5
|
+
|
|
6
|
+
__all__ = ["NamedEntityAnnotation", "TransformersNamedEntityExtractor"]
|