transformers-haystack 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- haystack_integrations/components/classifiers/py.typed +0 -0
- haystack_integrations/components/classifiers/transformers/__init__.py +6 -0
- haystack_integrations/components/classifiers/transformers/zero_shot_document_classifier.py +247 -0
- haystack_integrations/components/common/py.typed +0 -0
- haystack_integrations/components/common/transformers/__init__.py +3 -0
- haystack_integrations/components/common/transformers/utils.py +234 -0
- haystack_integrations/components/extractors/py.typed +0 -0
- haystack_integrations/components/extractors/transformers/__init__.py +6 -0
- haystack_integrations/components/extractors/transformers/named_entity_extractor.py +262 -0
- haystack_integrations/components/generators/py.typed +0 -0
- haystack_integrations/components/generators/transformers/__init__.py +6 -0
- haystack_integrations/components/generators/transformers/chat/__init__.py +3 -0
- haystack_integrations/components/generators/transformers/chat/chat_generator.py +666 -0
- haystack_integrations/components/readers/py.typed +0 -0
- haystack_integrations/components/readers/transformers/__init__.py +6 -0
- haystack_integrations/components/readers/transformers/extractive_reader.py +662 -0
- haystack_integrations/components/routers/py.typed +0 -0
- haystack_integrations/components/routers/transformers/__init__.py +7 -0
- haystack_integrations/components/routers/transformers/text_router.py +196 -0
- haystack_integrations/components/routers/transformers/zero_shot_text_router.py +205 -0
- transformers_haystack-0.1.0.dist-info/METADATA +38 -0
- transformers_haystack-0.1.0.dist-info/RECORD +24 -0
- transformers_haystack-0.1.0.dist-info/WHEEL +4 -0
- transformers_haystack-0.1.0.dist-info/licenses/LICENSE.txt +201 -0
|
@@ -0,0 +1,662 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
2
|
+
#
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
|
|
5
|
+
import math
|
|
6
|
+
from dataclasses import replace
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
import accelerate # noqa: F401 # the library is used but not directly referenced
|
|
11
|
+
import torch
|
|
12
|
+
from haystack import Document, ExtractedAnswer, component, default_from_dict, default_to_dict, logging
|
|
13
|
+
from haystack.utils import ComponentDevice, Device, DeviceMap, Secret
|
|
14
|
+
from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs
|
|
15
|
+
from tokenizers import Encoding
|
|
16
|
+
|
|
17
|
+
from haystack_integrations.components.common.transformers.utils import _resolve_hf_device_map
|
|
18
|
+
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@component
|
|
24
|
+
class TransformersExtractiveReader:
|
|
25
|
+
"""
|
|
26
|
+
Locates and extracts answers to a given query from Documents.
|
|
27
|
+
|
|
28
|
+
The TransformersExtractiveReader component performs extractive question answering.
|
|
29
|
+
It assigns a score to every possible answer span independently of other answer spans.
|
|
30
|
+
This fixes a common issue of other implementations which make comparisons across documents harder by normalizing
|
|
31
|
+
each document's answers independently.
|
|
32
|
+
|
|
33
|
+
Example usage:
|
|
34
|
+
```python
|
|
35
|
+
from haystack import Document
|
|
36
|
+
|
|
37
|
+
from haystack_integrations.components.readers.transformers import TransformersExtractiveReader
|
|
38
|
+
|
|
39
|
+
docs = [
|
|
40
|
+
Document(content="Python is a popular programming language"),
|
|
41
|
+
Document(content="python ist eine beliebte Programmiersprache"),
|
|
42
|
+
]
|
|
43
|
+
|
|
44
|
+
reader = TransformersExtractiveReader()
|
|
45
|
+
|
|
46
|
+
question = "What is a popular programming language?"
|
|
47
|
+
result = reader.run(query=question, documents=docs)
|
|
48
|
+
assert "Python" in result["answers"][0].data
|
|
49
|
+
```
|
|
50
|
+
"""
|
|
51
|
+
|
|
52
|
+
def __init__(
    self,
    model: Path | str = "deepset/roberta-base-squad2-distilled",
    device: ComponentDevice | None = None,
    token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
    top_k: int = 20,
    score_threshold: float | None = None,
    max_seq_length: int = 384,
    stride: int = 128,
    max_batch_size: int | None = None,
    answers_per_seq: int | None = None,
    no_answer: bool = True,
    calibration_factor: float = 0.1,
    overlap_threshold: float | None = 0.01,
    model_kwargs: dict[str, Any] | None = None,
) -> None:
    """
    Creates an instance of TransformersExtractiveReader.

    Only configuration is stored here; the model and tokenizer are loaded later by `warm_up`.

    :param model:
        A Hugging Face transformers question answering model, given either as a path to a folder with the
        model files or as an identifier on the Hugging Face hub.
    :param device:
        The device on which the model is loaded. If `None`, the default device is automatically selected.
    :param token:
        The API token used to download private models from Hugging Face.
    :param top_k:
        Number of answers to return per query. It is required even if score_threshold is set.
        An additional answer with no text is returned if no_answer is set to True (default).
    :param score_threshold:
        Only answers whose probability score reaches this threshold are returned.
    :param max_seq_length:
        Maximum number of tokens. If a sequence exceeds it, the sequence is split.
    :param stride:
        Number of tokens that overlap when a sequence is split because it exceeds max_seq_length.
    :param max_batch_size:
        Maximum number of samples that are fed through the model at the same time.
    :param answers_per_seq:
        Number of answer candidates to consider per sequence.
        This is relevant when a Document was split into multiple sequences because of max_seq_length.
    :param no_answer:
        Whether to return an additional `no answer` with an empty text and a score representing the
        probability that the other top_k answers are incorrect.
    :param calibration_factor:
        Factor used for calibrating probabilities.
    :param overlap_threshold:
        If set, duplicate answers are removed when their overlap is larger than the supplied threshold.
        For example, of the answers "in the river in Maine" and "the river" one is removed, because the
        second answer has a 100% (1.0) overlap with the first answer.
        However, for the answers "the river in" and "in Maine" there is only a max overlap percentage of 25%,
        so both of these answers could be kept if this variable is set to 0.24 or lower.
        If None is provided then all answers are kept.
    :param model_kwargs:
        Additional keyword arguments passed to `AutoModelForQuestionAnswering.from_pretrained`
        when loading the model specified in `model`. For details on what kwargs you can pass,
        see the model's documentation.
    """
    # What to load, and with which credentials.
    self.model_name_or_path = str(model)
    self.token = token

    # Populated by `warm_up`; until then the component holds no model, tokenizer or device.
    self.model = None
    self.tokenizer: Any = None
    self.device: ComponentDevice | None = None

    # Tokenization and batching settings.
    self.max_seq_length = max_seq_length
    self.stride = stride
    self.max_batch_size = max_batch_size

    # Answer selection and scoring settings.
    self.top_k = top_k
    self.score_threshold = score_threshold
    self.answers_per_seq = answers_per_seq
    self.no_answer = no_answer
    self.calibration_factor = calibration_factor
    self.overlap_threshold = overlap_threshold

    # The requested device is merged into the kwargs that are later handed to `from_pretrained`.
    self.model_kwargs = _resolve_hf_device_map(device=device, model_kwargs=model_kwargs)
|
|
126
|
+
|
|
127
|
+
def _get_telemetry_data(self) -> dict[str, Any]:
|
|
128
|
+
"""
|
|
129
|
+
Data that is sent to Posthog for usage analytics.
|
|
130
|
+
"""
|
|
131
|
+
return {"model": self.model_name_or_path}
|
|
132
|
+
|
|
133
|
+
def to_dict(self) -> dict[str, Any]:
    """
    Serializes the component to a dictionary.

    All init parameters that influence answer extraction are written out, including `overlap_threshold`,
    so that a `to_dict` / `from_dict` round trip rebuilds a component with the same configuration.

    :returns:
        Dictionary with serialized data.
    """
    serialization_dict = default_to_dict(
        self,
        model=self.model_name_or_path,
        # NOTE(review): `device` is written as None; presumably the device chosen at init time already
        # lives inside `model_kwargs` (see `_resolve_hf_device_map` in `__init__`) - confirm.
        device=None,
        token=self.token,
        max_seq_length=self.max_seq_length,
        top_k=self.top_k,
        score_threshold=self.score_threshold,
        stride=self.stride,
        max_batch_size=self.max_batch_size,
        answers_per_seq=self.answers_per_seq,
        no_answer=self.no_answer,
        calibration_factor=self.calibration_factor,
        # Previously missing: without it a custom threshold was lost on deserialization.
        overlap_threshold=self.overlap_threshold,
        model_kwargs=self.model_kwargs,
    )

    # Converts the Hugging Face specific kwargs in place into a serializable form.
    serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
    return serialization_dict
|
|
158
|
+
|
|
159
|
+
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "TransformersExtractiveReader":
    """
    Deserializes the component from a dictionary.

    :param data:
        Dictionary to deserialize from. Its `init_parameters` entry is read and, when it carries
        `model_kwargs`, those are deserialized in place before the component is built.
    :returns:
        Deserialized component.
    """
    serialized_model_kwargs = data["init_parameters"].get("model_kwargs")
    if serialized_model_kwargs is not None:
        deserialize_hf_model_kwargs(serialized_model_kwargs)

    return default_from_dict(cls, data)
|
|
174
|
+
|
|
175
|
+
def warm_up(self) -> None:
    """
    Initializes the component.

    Loads the question answering model and its tokenizer, then records the device(s) the model ended up on.
    Calling it again after the model has been loaded does nothing.
    """
    if self.model is not None:
        return

    hf_token = self.token.resolve_value() if self.token else None
    self.model = AutoModelForQuestionAnswering.from_pretrained(
        self.model_name_or_path, token=hf_token, **self.model_kwargs
    )
    self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, token=hf_token)
    assert self.model is not None  # noqa: S101 # mypy doesn't know this is set in the line above

    # hf_device_map appears to only be set now when mixed devices are actually used.
    # So if it's missing then we can use the device attribute which is set even for single-device models.
    # The recorded device is what `_preprocess` later uses to place the tokenizer outputs.
    hf_device_map = getattr(self.model, "hf_device_map", None)
    if hf_device_map:
        self.device = ComponentDevice.from_multiple(device_map=DeviceMap.from_hf(hf_device_map))
    else:
        self.device = ComponentDevice.from_single(Device.from_str(str(self.model.device)))
|
|
194
|
+
|
|
195
|
+
@staticmethod
|
|
196
|
+
def _flatten_documents(
|
|
197
|
+
queries: list[str], documents: list[list[Document]]
|
|
198
|
+
) -> tuple[list[str], list[Document], list[int]]:
|
|
199
|
+
"""
|
|
200
|
+
Flattens queries and Documents so all query-document pairs are arranged along one batch axis.
|
|
201
|
+
"""
|
|
202
|
+
flattened_queries = [query for documents_, query in zip(documents, queries, strict=True) for _ in documents_]
|
|
203
|
+
flattened_documents = [document for documents_ in documents for document in documents_]
|
|
204
|
+
query_ids = [i for i, documents_ in enumerate(documents) for _ in documents_]
|
|
205
|
+
return flattened_queries, flattened_documents, query_ids
|
|
206
|
+
|
|
207
|
+
def _preprocess(
    self, *, queries: list[str], documents: list[Document], max_seq_length: int, query_ids: list[int], stride: int
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", list["Encoding"], list[int], list[int]]:
    """
    Splits and tokenizes Documents and preserves structures by returning mappings to query and Document IDs.

    Documents whose content is `None` are skipped with a warning. The matching entries of `queries` and
    `query_ids` are skipped too, so the query batch and the document batch handed to the tokenizer always
    have the same length and the returned `query_ids` stay aligned with the produced sequences.

    :param queries:
        One query per entry of `documents`.
    :param documents:
        Flattened Documents, parallel to `queries`.
    :param max_seq_length:
        Maximum number of tokens per sequence; longer inputs are split into overflowing sequences.
    :param query_ids:
        For each entry of `documents`, the index of the query it belongs to.
    :param stride:
        Number of overlapping tokens between consecutive sequences of a split input.
    :returns:
        `(input_ids, attention_mask, sequence_ids, encodings, query_ids, document_ids)` where the last two
        lists hold, for each produced sequence, the query index and the position in `documents` it came
        from. In `sequence_ids`, positions without a sequence id are encoded as -1.
    """
    kept_queries: list[str] = []
    kept_query_ids: list[int] = []
    document_ids: list[int] = []
    document_contents: list[str] = []
    for i, doc in enumerate(documents):
        if doc.content is None:
            logger.warning(
                "Document with id {doc_id} was passed to TransformersExtractiveReader. The Document doesn't "
                "contain any text and it will be ignored.",
                doc_id=doc.id,
            )
            continue
        # Keep the query-side lists in lockstep with the documents that survive the filter.
        kept_queries.append(queries[i])
        kept_query_ids.append(query_ids[i])
        document_ids.append(i)
        document_contents.append(doc.content)

    # mypy doesn't know this is set in warm_up
    encodings_pt = self.tokenizer(
        kept_queries,
        document_contents,
        padding=True,
        truncation=True,
        max_length=max_seq_length,
        return_tensors="pt",
        return_overflowing_tokens=True,
        stride=stride,
    )

    # Take the first device used by `accelerate`. Needed to pass inputs from the tokenizer to the correct device.
    # mypy doesn't know this is set in warm_up
    first_device = self.device.first_device.to_torch()  # type: ignore[union-attr]

    input_ids = encodings_pt.input_ids.to(first_device)
    attention_mask = encodings_pt.attention_mask.to(first_device)

    # `overflow_to_sample_mapping` indexes into the filtered batch, so both lookups use the filtered lists.
    query_ids = [kept_query_ids[sample_id] for sample_id in encodings_pt.overflow_to_sample_mapping]
    document_ids = [document_ids[sample_id] for sample_id in encodings_pt.overflow_to_sample_mapping]

    encodings = encodings_pt.encodings
    sequence_ids = torch.tensor(
        [[id_ if id_ is not None else -1 for id_ in encoding.sequence_ids] for encoding in encodings]
    ).to(first_device)

    return input_ids, attention_mask, sequence_ids, encodings, query_ids, document_ids
|
|
256
|
+
|
|
257
|
+
def _postprocess(
|
|
258
|
+
self,
|
|
259
|
+
*,
|
|
260
|
+
start: "torch.Tensor",
|
|
261
|
+
end: "torch.Tensor",
|
|
262
|
+
sequence_ids: "torch.Tensor",
|
|
263
|
+
attention_mask: "torch.Tensor",
|
|
264
|
+
answers_per_seq: int,
|
|
265
|
+
encodings: list["Encoding"],
|
|
266
|
+
) -> tuple[list[list[int]], list[list[int]], list["torch.Tensor"]]:
|
|
267
|
+
"""
|
|
268
|
+
Turns start and end logits into probabilities for each answer span.
|
|
269
|
+
|
|
270
|
+
Unlike most other implementations, it doesn't normalize the scores in each split to make them easier to
|
|
271
|
+
compare across different splits. Returns the top k answer spans.
|
|
272
|
+
"""
|
|
273
|
+
mask = sequence_ids == 1 # Only keep tokens from the context (should ignore special tokens)
|
|
274
|
+
mask = torch.logical_and(mask, attention_mask == 1) # Definitely remove special tokens
|
|
275
|
+
start = torch.where(mask, start, -torch.inf) # Apply the mask on the start logits
|
|
276
|
+
end = torch.where(mask, end, -torch.inf) # Apply the mask on the end logits
|
|
277
|
+
start = start.unsqueeze(-1)
|
|
278
|
+
end = end.unsqueeze(-2)
|
|
279
|
+
|
|
280
|
+
logits = start + end # shape: (batch_size, seq_length (start), seq_length (end))
|
|
281
|
+
|
|
282
|
+
# The mask here onwards is the same for all instances in the batch
|
|
283
|
+
# As such we do away with the batch dimension
|
|
284
|
+
mask = torch.ones(logits.shape[-2:], dtype=torch.bool, device=logits.device)
|
|
285
|
+
mask = torch.triu(mask) # End shouldn't be before start
|
|
286
|
+
masked_logits = torch.where(mask, logits, -torch.inf)
|
|
287
|
+
probabilities = torch.sigmoid(masked_logits * self.calibration_factor)
|
|
288
|
+
|
|
289
|
+
flat_probabilities = probabilities.flatten(-2, -1) # necessary for top-k
|
|
290
|
+
|
|
291
|
+
# top-k can return invalid candidates as well if answers_per_seq > num_valid_candidates
|
|
292
|
+
# We only keep probability > 0 candidates later on
|
|
293
|
+
candidates = torch.topk(flat_probabilities, answers_per_seq)
|
|
294
|
+
seq_length = logits.shape[-1]
|
|
295
|
+
start_candidates = candidates.indices // seq_length # Recover indices from flattening
|
|
296
|
+
end_candidates = candidates.indices % seq_length
|
|
297
|
+
candidates_values = candidates.values.cpu()
|
|
298
|
+
start_candidates = start_candidates.cpu()
|
|
299
|
+
end_candidates = end_candidates.cpu()
|
|
300
|
+
|
|
301
|
+
start_candidates_tokens_to_chars = []
|
|
302
|
+
end_candidates_tokens_to_chars = []
|
|
303
|
+
valid_candidates_values: list[torch.Tensor] = []
|
|
304
|
+
for i, (s_candidates, e_candidates, encoding) in enumerate(
|
|
305
|
+
zip(start_candidates, end_candidates, encodings, strict=True)
|
|
306
|
+
):
|
|
307
|
+
# Those with probabilities > 0 are valid. topk may include masked candidates
|
|
308
|
+
# when answers_per_seq exceeds the number of valid spans, so filter all three lists together.
|
|
309
|
+
valid = candidates_values[i] > 0
|
|
310
|
+
s_char_spans = []
|
|
311
|
+
e_char_spans = []
|
|
312
|
+
for start_token, end_token in zip(s_candidates[valid], e_candidates[valid], strict=True):
|
|
313
|
+
# token_to_chars returns `None` for special tokens
|
|
314
|
+
# But we shouldn't have special tokens in the answers at this point
|
|
315
|
+
# The whole span is given by the start of the start_token (index 0)
|
|
316
|
+
# and the end of the end token (index 1)
|
|
317
|
+
s_char_spans.append(encoding.token_to_chars(start_token)[0])
|
|
318
|
+
e_char_spans.append(encoding.token_to_chars(end_token)[1])
|
|
319
|
+
start_candidates_tokens_to_chars.append(s_char_spans)
|
|
320
|
+
end_candidates_tokens_to_chars.append(e_char_spans)
|
|
321
|
+
valid_candidates_values.append(candidates_values[i][valid])
|
|
322
|
+
|
|
323
|
+
return start_candidates_tokens_to_chars, end_candidates_tokens_to_chars, valid_candidates_values
|
|
324
|
+
|
|
325
|
+
def _add_answer_page_number(self, answer: ExtractedAnswer) -> ExtractedAnswer:
|
|
326
|
+
if answer.meta is None:
|
|
327
|
+
answer = replace(answer, meta={})
|
|
328
|
+
|
|
329
|
+
if answer.document_offset is None:
|
|
330
|
+
return answer
|
|
331
|
+
|
|
332
|
+
if not answer.document or "page_number" not in answer.document.meta:
|
|
333
|
+
return answer
|
|
334
|
+
|
|
335
|
+
if not isinstance(answer.document.meta["page_number"], int):
|
|
336
|
+
logger.warning(
|
|
337
|
+
"Document's page_number must be int but is {type}. No page number will be added to the answer.",
|
|
338
|
+
type=type(answer.document.meta["page_number"]),
|
|
339
|
+
)
|
|
340
|
+
return answer
|
|
341
|
+
|
|
342
|
+
# Calculate the answer page number
|
|
343
|
+
if answer.document.content:
|
|
344
|
+
ans_start = answer.document_offset.start
|
|
345
|
+
answer_page_number = answer.document.meta["page_number"] + answer.document.content[:ans_start].count("\f")
|
|
346
|
+
answer.meta.update({"answer_page_number": answer_page_number})
|
|
347
|
+
|
|
348
|
+
return answer
|
|
349
|
+
|
|
350
|
+
def _nest_answers(
    self,
    *,
    start: list[list[int]],
    end: list[list[int]],
    probabilities: list["torch.Tensor"],
    flattened_documents: list[Document],
    queries: list[str],
    answers_per_seq: int,
    top_k: int | None,
    score_threshold: float | None,
    query_ids: list[int],
    document_ids: list[int],
    no_answer: bool,
    overlap_threshold: float | None,
) -> list[list[ExtractedAnswer]]:
    """
    Reconstructs the nested structure that existed before flattening.

    Also computes a no answer score. This score is different from most other implementations because it does not
    consider the no answer logit introduced with SQuAD 2. Instead, it just computes the probability that the
    answer does not exist in the top k or top p.

    Every candidate is assigned to its query through the `query_ids` entry of the sequence it came from, so
    the grouping stays correct when sequences contribute different numbers of candidates.

    :param start:
        Per sequence, the character start offsets of its candidate answers.
    :param end:
        Per sequence, the character end offsets of its candidate answers.
    :param probabilities:
        Per sequence, the scores of its candidate answers.
    :param flattened_documents:
        The Documents that `document_ids` index into.
    :param queries:
        The queries that `query_ids` index into; one list of answers is returned per query.
    :param answers_per_seq:
        Not used for grouping anymore; kept so that existing callers keep working.
    :param top_k:
        Maximum number of extracted answers per query (the no answer entry comes on top).
    :param score_threshold:
        If set, answers with a score below it are dropped.
    :param query_ids:
        Per sequence, the index of the query it belongs to.
    :param document_ids:
        Per sequence, the index of the Document it was cut from.
    :param no_answer:
        Whether to add an answer without data whose score is the product of `1 - score` over the kept answers.
    :param overlap_threshold:
        Passed on to `deduplicate_by_overlap`.
    :returns:
        For each query, its answers sorted by descending score.
    """
    answers_by_query: list[list[ExtractedAnswer]] = [[] for _ in queries]
    for query_id, document_id, start_candidates_, end_candidates_, probabilities_ in zip(
        query_ids, document_ids, start, end, probabilities, strict=True
    ):
        doc = flattened_documents[document_id]
        for start_, end_, probability in zip(start_candidates_, end_candidates_, probabilities_, strict=True):
            answers_by_query[query_id].append(
                ExtractedAnswer(
                    query=queries[query_id],
                    data=doc.content[start_:end_],  # type: ignore
                    document=doc,
                    score=probability.item(),
                    document_offset=ExtractedAnswer.Span(start_, end_),
                    meta={},
                )
            )

    nested_answers = []
    for query, candidates in zip(queries, answers_by_query, strict=True):
        current_answers = sorted(candidates, key=lambda ans: ans.score, reverse=True)
        # A query may end up without any valid candidate (e.g. a Document with empty content);
        # there is nothing to deduplicate in that case.
        if current_answers:
            current_answers = self.deduplicate_by_overlap(current_answers, overlap_threshold=overlap_threshold)
        current_answers = current_answers[:top_k]

        # Calculate the answer page number and add it to meta
        current_answers = [self._add_answer_page_number(answer=answer) for answer in current_answers]

        if no_answer:
            # Probability that none of the kept answers is correct; 1 when no answer was kept.
            no_answer_score = math.prod(1 - answer.score for answer in current_answers)
            answer_ = ExtractedAnswer(data=None, query=query, meta={}, document=None, score=no_answer_score)
            current_answers.append(answer_)
        current_answers = sorted(current_answers, key=lambda ans: ans.score, reverse=True)
        if score_threshold is not None:
            current_answers = [answer for answer in current_answers if answer.score >= score_threshold]
        nested_answers.append(current_answers)

    return nested_answers
|
|
419
|
+
|
|
420
|
+
def _calculate_overlap(self, answer1_start: int, answer1_end: int, answer2_start: int, answer2_end: int) -> int:
|
|
421
|
+
"""
|
|
422
|
+
Calculates the amount of overlap (in number of characters) between two answer offsets.
|
|
423
|
+
|
|
424
|
+
This Stack overflow
|
|
425
|
+
[post](https://stackoverflow.com/questions/325933/determine-whether-two-date-ranges-overlap/325964#325964)
|
|
426
|
+
explains how to calculate the overlap between two ranges.
|
|
427
|
+
"""
|
|
428
|
+
# Check for overlap: (StartA <= EndB) and (StartB <= EndA)
|
|
429
|
+
if answer1_start <= answer2_end and answer2_start <= answer1_end:
|
|
430
|
+
return min(
|
|
431
|
+
answer1_end - answer1_start,
|
|
432
|
+
answer1_end - answer2_start,
|
|
433
|
+
answer2_end - answer1_start,
|
|
434
|
+
answer2_end - answer2_start,
|
|
435
|
+
)
|
|
436
|
+
return 0
|
|
437
|
+
|
|
438
|
+
def _should_keep(
|
|
439
|
+
self, candidate_answer: ExtractedAnswer, current_answers: list[ExtractedAnswer], overlap_threshold: float
|
|
440
|
+
) -> bool:
|
|
441
|
+
"""
|
|
442
|
+
Determines if the answer should be kept based on how much it overlaps with previous answers.
|
|
443
|
+
|
|
444
|
+
NOTE: We might want to avoid throwing away answers that only have a few character (or word) overlap:
|
|
445
|
+
- E.g. The answers "the river in" and "in Maine" from the context "I want to go to the river in Maine."
|
|
446
|
+
might both want to be kept.
|
|
447
|
+
|
|
448
|
+
:param candidate_answer:
|
|
449
|
+
Candidate answer that will be checked if it should be kept.
|
|
450
|
+
:param current_answers:
|
|
451
|
+
Current list of answers that will be kept.
|
|
452
|
+
:param overlap_threshold:
|
|
453
|
+
If the overlap between two answers is greater than this threshold then return False.
|
|
454
|
+
"""
|
|
455
|
+
keep = True
|
|
456
|
+
|
|
457
|
+
# If the candidate answer doesn't have a document keep it
|
|
458
|
+
if not candidate_answer.document:
|
|
459
|
+
return keep
|
|
460
|
+
|
|
461
|
+
for ans in current_answers:
|
|
462
|
+
# If an answer in current_answers doesn't have a document skip the comparison
|
|
463
|
+
if not ans.document:
|
|
464
|
+
continue
|
|
465
|
+
|
|
466
|
+
# If offset is missing then keep both
|
|
467
|
+
if ans.document_offset is None:
|
|
468
|
+
continue
|
|
469
|
+
|
|
470
|
+
# If offset is missing then keep both
|
|
471
|
+
if candidate_answer.document_offset is None:
|
|
472
|
+
continue
|
|
473
|
+
|
|
474
|
+
# If the answers come from different documents then keep both
|
|
475
|
+
if candidate_answer.document.id != ans.document.id:
|
|
476
|
+
continue
|
|
477
|
+
|
|
478
|
+
overlap_len = self._calculate_overlap(
|
|
479
|
+
answer1_start=ans.document_offset.start,
|
|
480
|
+
answer1_end=ans.document_offset.end,
|
|
481
|
+
answer2_start=candidate_answer.document_offset.start,
|
|
482
|
+
answer2_end=candidate_answer.document_offset.end,
|
|
483
|
+
)
|
|
484
|
+
|
|
485
|
+
# If overlap is 0 then keep
|
|
486
|
+
if overlap_len == 0:
|
|
487
|
+
continue
|
|
488
|
+
|
|
489
|
+
overlap_frac_answer1 = overlap_len / (ans.document_offset.end - ans.document_offset.start)
|
|
490
|
+
overlap_frac_answer2 = overlap_len / (
|
|
491
|
+
candidate_answer.document_offset.end - candidate_answer.document_offset.start
|
|
492
|
+
)
|
|
493
|
+
|
|
494
|
+
if overlap_frac_answer1 > overlap_threshold or overlap_frac_answer2 > overlap_threshold:
|
|
495
|
+
keep = False
|
|
496
|
+
break
|
|
497
|
+
|
|
498
|
+
return keep
|
|
499
|
+
|
|
500
|
+
def deduplicate_by_overlap(
|
|
501
|
+
self, answers: list[ExtractedAnswer], overlap_threshold: float | None
|
|
502
|
+
) -> list[ExtractedAnswer]:
|
|
503
|
+
"""
|
|
504
|
+
De-duplicates overlapping Extractive Answers.
|
|
505
|
+
|
|
506
|
+
De-duplicates overlapping Extractive Answers from the same document based on how much the spans of the
|
|
507
|
+
answers overlap.
|
|
508
|
+
|
|
509
|
+
:param answers:
|
|
510
|
+
List of answers to be deduplicated.
|
|
511
|
+
:param overlap_threshold:
|
|
512
|
+
If set this will remove duplicate answers if they have an overlap larger than the
|
|
513
|
+
supplied threshold. For example, for the answers "in the river in Maine" and "the river" we would remove
|
|
514
|
+
one of these answers since the second answer has a 100% (1.0) overlap with the first answer.
|
|
515
|
+
However, for the answers "the river in" and "in Maine" there is only a max overlap percentage of 25% so
|
|
516
|
+
both of these answers could be kept if this variable is set to 0.24 or lower.
|
|
517
|
+
If None is provided then all answers are kept.
|
|
518
|
+
:returns:
|
|
519
|
+
List of deduplicated answers.
|
|
520
|
+
"""
|
|
521
|
+
if overlap_threshold is None:
|
|
522
|
+
return answers
|
|
523
|
+
|
|
524
|
+
# Initialize with the first answer and its offsets_in_document
|
|
525
|
+
deduplicated_answers = [answers[0]]
|
|
526
|
+
|
|
527
|
+
# Loop over remaining answers to check for overlaps
|
|
528
|
+
for ans in answers[1:]:
|
|
529
|
+
keep = self._should_keep(
|
|
530
|
+
candidate_answer=ans, current_answers=deduplicated_answers, overlap_threshold=overlap_threshold
|
|
531
|
+
)
|
|
532
|
+
if keep:
|
|
533
|
+
deduplicated_answers.append(ans)
|
|
534
|
+
|
|
535
|
+
return deduplicated_answers
|
|
536
|
+
|
|
537
|
+
@component.output_types(answers=list[ExtractedAnswer])
def run(
    self,
    query: str,
    documents: list[Document],
    top_k: int | None = None,
    score_threshold: float | None = None,
    max_seq_length: int | None = None,
    stride: int | None = None,
    max_batch_size: int | None = None,
    answers_per_seq: int | None = None,
    no_answer: bool | None = None,
    overlap_threshold: float | None = None,
) -> dict[str, Any]:
    """
    Locates and extracts answers from the given Documents using the given query.

    Every optional argument left at `None` falls back to the value given at init time. The model is loaded
    on first use if `warm_up` has not been called before.

    :param query:
        Query string.
    :param documents:
        List of Documents in which you want to search for an answer to the query.
    :param top_k:
        The maximum number of answers to return.
        An additional answer is returned if no_answer is set to True (default).
    :param score_threshold:
        Returns only answers whose score reaches this threshold. A value of 0.0 is honored.
    :param max_seq_length:
        Maximum number of tokens. If a sequence exceeds it, the sequence is split.
    :param stride:
        Number of tokens that overlap when sequence is split because it exceeds max_seq_length.
        A value of 0 is honored.
    :param max_batch_size:
        Maximum number of samples that are fed through the model at the same time.
    :param answers_per_seq:
        Number of answer candidates to consider per sequence.
        This is relevant when a Document was split into multiple sequences because of max_seq_length.
        Defaults to 20 when it is set neither here nor at init time.
    :param no_answer:
        Whether to return no answer scores.
    :param overlap_threshold:
        If set this will remove duplicate answers if they have an overlap larger than the
        supplied threshold. For example, for the answers "in the river in Maine" and "the river" we would remove
        one of these answers since the second answer has a 100% (1.0) overlap with the first answer.
        However, for the answers "the river in" and "in Maine" there is only a max overlap percentage of 25% so
        both of these answers could be kept if this variable is set to 0.24 or lower.
        A value of 0.0 is honored. If None is provided, the value given at init time is used.
    :returns:
        Dictionary with the key `answers` holding the list of answers sorted by (desc.) answer score.
    """
    if self.model is None:
        self.warm_up()

    if not documents:
        return {"answers": []}

    queries = [query]  # Temporary solution until we have decided what batching should look like in v2
    nested_documents = [documents]

    # For these settings 0 is not a usable value, so any falsy argument means "use the init value".
    top_k = top_k or self.top_k
    max_seq_length = max_seq_length or self.max_seq_length
    max_batch_size = max_batch_size or self.max_batch_size
    answers_per_seq = answers_per_seq or self.answers_per_seq or 20
    # For these settings 0 / 0.0 / False are meaningful, so only `None` falls back to the init value.
    score_threshold = score_threshold if score_threshold is not None else self.score_threshold
    stride = stride if stride is not None else self.stride
    no_answer = no_answer if no_answer is not None else self.no_answer
    overlap_threshold = overlap_threshold if overlap_threshold is not None else self.overlap_threshold

    flattened_queries, flattened_documents, query_ids = TransformersExtractiveReader._flatten_documents(
        queries, nested_documents
    )
    input_ids, attention_mask, sequence_ids, encodings, query_ids, document_ids = self._preprocess(
        queries=flattened_queries,
        documents=flattened_documents,
        max_seq_length=max_seq_length,
        query_ids=query_ids,
        stride=stride,
    )

    num_batches = math.ceil(input_ids.shape[0] / max_batch_size) if max_batch_size else 1
    batch_size = max_batch_size or input_ids.shape[0]

    start_logits_list = []
    end_logits_list = []

    for i in range(num_batches):
        start_index = i * batch_size
        end_index = start_index + batch_size
        cur_input_ids = input_ids[start_index:end_index]
        cur_attention_mask = attention_mask[start_index:end_index]

        with torch.inference_mode():
            # mypy doesn't know this is set in warm_up
            output = self.model(input_ids=cur_input_ids, attention_mask=cur_attention_mask)  # type: ignore[misc]
        cur_start_logits = output.start_logits
        cur_end_logits = output.end_logits
        # With several batches the logits are collected on the CPU instead of staying on the model device.
        if num_batches != 1:
            cur_start_logits = cur_start_logits.cpu()
            cur_end_logits = cur_end_logits.cpu()
        start_logits_list.append(cur_start_logits)
        end_logits_list.append(cur_end_logits)

    start_logits = torch.cat(start_logits_list)
    end_logits = torch.cat(end_logits_list)

    start, end, probabilities = self._postprocess(
        start=start_logits,
        end=end_logits,
        sequence_ids=sequence_ids,
        attention_mask=attention_mask,
        answers_per_seq=answers_per_seq,
        encodings=encodings,
    )

    answers = self._nest_answers(
        start=start,
        end=end,
        probabilities=probabilities,
        flattened_documents=flattened_documents,
        queries=queries,
        answers_per_seq=answers_per_seq,
        top_k=top_k,
        score_threshold=score_threshold,
        query_ids=query_ids,
        document_ids=document_ids,
        no_answer=no_answer,
        overlap_threshold=overlap_threshold,
    )

    return {"answers": answers[0]}  # same temporary batching fix as above
|
|
File without changes
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
|
|
2
|
+
#
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
from .text_router import TransformersTextRouter
|
|
5
|
+
from .zero_shot_text_router import TransformersZeroShotTextRouter
|
|
6
|
+
|
|
7
|
+
__all__ = ["TransformersTextRouter", "TransformersZeroShotTextRouter"]
|