transformers-haystack 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24) hide show
  1. haystack_integrations/components/classifiers/py.typed +0 -0
  2. haystack_integrations/components/classifiers/transformers/__init__.py +6 -0
  3. haystack_integrations/components/classifiers/transformers/zero_shot_document_classifier.py +247 -0
  4. haystack_integrations/components/common/py.typed +0 -0
  5. haystack_integrations/components/common/transformers/__init__.py +3 -0
  6. haystack_integrations/components/common/transformers/utils.py +234 -0
  7. haystack_integrations/components/extractors/py.typed +0 -0
  8. haystack_integrations/components/extractors/transformers/__init__.py +6 -0
  9. haystack_integrations/components/extractors/transformers/named_entity_extractor.py +262 -0
  10. haystack_integrations/components/generators/py.typed +0 -0
  11. haystack_integrations/components/generators/transformers/__init__.py +6 -0
  12. haystack_integrations/components/generators/transformers/chat/__init__.py +3 -0
  13. haystack_integrations/components/generators/transformers/chat/chat_generator.py +666 -0
  14. haystack_integrations/components/readers/py.typed +0 -0
  15. haystack_integrations/components/readers/transformers/__init__.py +6 -0
  16. haystack_integrations/components/readers/transformers/extractive_reader.py +662 -0
  17. haystack_integrations/components/routers/py.typed +0 -0
  18. haystack_integrations/components/routers/transformers/__init__.py +7 -0
  19. haystack_integrations/components/routers/transformers/text_router.py +196 -0
  20. haystack_integrations/components/routers/transformers/zero_shot_text_router.py +205 -0
  21. transformers_haystack-0.1.0.dist-info/METADATA +38 -0
  22. transformers_haystack-0.1.0.dist-info/RECORD +24 -0
  23. transformers_haystack-0.1.0.dist-info/WHEEL +4 -0
  24. transformers_haystack-0.1.0.dist-info/licenses/LICENSE.txt +201 -0
@@ -0,0 +1,662 @@
1
+ # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+
5
+ import math
6
+ from dataclasses import replace
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ import accelerate # noqa: F401 # the library is used but not directly referenced
11
+ import torch
12
+ from haystack import Document, ExtractedAnswer, component, default_from_dict, default_to_dict, logging
13
+ from haystack.utils import ComponentDevice, Device, DeviceMap, Secret
14
+ from haystack.utils.hf import deserialize_hf_model_kwargs, serialize_hf_model_kwargs
15
+ from tokenizers import Encoding
16
+
17
+ from haystack_integrations.components.common.transformers.utils import _resolve_hf_device_map
18
+ from transformers import AutoModelForQuestionAnswering, AutoTokenizer
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ @component
24
+ class TransformersExtractiveReader:
25
+ """
26
+ Locates and extracts answers to a given query from Documents.
27
+
28
+ The TransformersExtractiveReader component performs extractive question answering.
29
+ It assigns a score to every possible answer span independently of other answer spans.
30
+ This fixes a common issue of other implementations which make comparisons across documents harder by normalizing
31
+ each document's answers independently.
32
+
33
+ Example usage:
34
+ ```python
35
+ from haystack import Document
36
+
37
+ from haystack_integrations.components.readers.transformers import TransformersExtractiveReader
38
+
39
+ docs = [
40
+ Document(content="Python is a popular programming language"),
41
+ Document(content="python ist eine beliebte Programmiersprache"),
42
+ ]
43
+
44
+ reader = TransformersExtractiveReader()
45
+
46
+ question = "What is a popular programming language?"
47
+ result = reader.run(query=question, documents=docs)
48
+ assert "Python" in result["answers"][0].data
49
+ ```
50
+ """
51
+
52
def __init__(
    self,
    model: Path | str = "deepset/roberta-base-squad2-distilled",
    device: ComponentDevice | None = None,
    token: Secret | None = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
    top_k: int = 20,
    score_threshold: float | None = None,
    max_seq_length: int = 384,
    stride: int = 128,
    max_batch_size: int | None = None,
    answers_per_seq: int | None = None,
    no_answer: bool = True,
    calibration_factor: float = 0.1,
    overlap_threshold: float | None = 0.01,
    model_kwargs: dict[str, Any] | None = None,
) -> None:
    """
    Creates an instance of TransformersExtractiveReader.

    Only configuration is stored here; the model and tokenizer are loaded by `warm_up`.

    :param model:
        A Hugging Face transformers question answering model, given either as a path to a folder
        containing the model files or as an identifier on the Hugging Face hub.
    :param device:
        The device on which the model is loaded. If `None`, the default device is automatically selected.
    :param token:
        The API token used to download private models from Hugging Face.
    :param top_k:
        Number of answers to return per query. It is required even if score_threshold is set.
        An additional answer with no text is returned if no_answer is set to True (default).
    :param score_threshold:
        Returns only answers with the probability score above this threshold.
    :param max_seq_length:
        Maximum number of tokens. If a sequence exceeds it, the sequence is split.
    :param stride:
        Number of tokens that overlap when sequence is split because it exceeds max_seq_length.
    :param max_batch_size:
        Maximum number of samples that are fed through the model at the same time.
    :param answers_per_seq:
        Number of answer candidates to consider per sequence.
        This is relevant when a Document was split into multiple sequences because of max_seq_length.
    :param no_answer:
        Whether to return an additional `no answer` with an empty text and a score representing the
        probability that the other top_k answers are incorrect.
    :param calibration_factor:
        Factor used for calibrating probabilities.
    :param overlap_threshold:
        If set, duplicate answers whose overlap is larger than this threshold are removed.
        For example, of the answers "in the river in Maine" and "the river" one is removed, because the
        second answer has a 100% (1.0) overlap with the first. The answers "the river in" and "in Maine"
        only have a max overlap percentage of 25%, so both could be kept if this is set to 0.24 or lower.
        If None is provided then all answers are kept.
    :param model_kwargs:
        Additional keyword arguments passed to `AutoModelForQuestionAnswering.from_pretrained`
        when loading the model specified in `model`. For details on what kwargs you can pass,
        see the model's documentation.
    """
    # Runtime state: stays empty until `warm_up` populates it.
    self.model = None
    self.tokenizer: Any = None
    self.device: ComponentDevice | None = None

    # Model location and credentials.
    self.model_name_or_path = str(model)
    self.token = token

    # Tokenization / batching settings.
    self.max_seq_length = max_seq_length
    self.stride = stride
    self.max_batch_size = max_batch_size

    # Answer selection settings.
    self.top_k = top_k
    self.score_threshold = score_threshold
    self.answers_per_seq = answers_per_seq
    self.no_answer = no_answer
    self.calibration_factor = calibration_factor
    self.overlap_threshold = overlap_threshold

    # The requested device is handed to the helper together with the user-supplied kwargs;
    # whatever it returns is what `warm_up` forwards to `from_pretrained`.
    self.model_kwargs = _resolve_hf_device_map(device=device, model_kwargs=model_kwargs)
126
+
127
+ def _get_telemetry_data(self) -> dict[str, Any]:
128
+ """
129
+ Data that is sent to Posthog for usage analytics.
130
+ """
131
+ return {"model": self.model_name_or_path}
132
+
133
def to_dict(self) -> dict[str, Any]:
    """
    Serializes the component to a dictionary.

    All configuration stored by `__init__` is written out, including `overlap_threshold`, so that a
    `to_dict` / `from_dict` round trip restores the same de-duplication behavior instead of silently
    falling back to the constructor default.

    :returns:
        Dictionary with serialized data.
    """
    serialization_dict = default_to_dict(
        self,
        model=self.model_name_or_path,
        # `device` is always written as None; presumably the device choice already lives in
        # `model_kwargs` after `_resolve_hf_device_map` ran in `__init__` — confirm against that helper.
        device=None,
        token=self.token,
        max_seq_length=self.max_seq_length,
        top_k=self.top_k,
        score_threshold=self.score_threshold,
        stride=self.stride,
        max_batch_size=self.max_batch_size,
        answers_per_seq=self.answers_per_seq,
        no_answer=self.no_answer,
        calibration_factor=self.calibration_factor,
        overlap_threshold=self.overlap_threshold,
        model_kwargs=self.model_kwargs,
    )

    # Converts the model kwargs in place into a serializable form.
    serialize_hf_model_kwargs(serialization_dict["init_parameters"]["model_kwargs"])
    return serialization_dict
158
+
159
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "TransformersExtractiveReader":
    """
    Deserializes the component from a dictionary.

    :param data:
        Dictionary to deserialize from. Its `init_parameters["model_kwargs"]` entry, when present,
        is converted back in place before the component is built.
    :returns:
        Deserialized component.
    """
    serialized_model_kwargs = data["init_parameters"].get("model_kwargs")
    if serialized_model_kwargs is not None:
        # In-place counterpart of the conversion applied in `to_dict`.
        deserialize_hf_model_kwargs(serialized_model_kwargs)

    return default_from_dict(cls, data)
174
+
175
def warm_up(self) -> None:
    """
    Initializes the component.

    Loads the question answering model and its tokenizer on the first call, then records on
    `self.device` where the model lives.
    """
    model_missing = self.model is None
    if model_missing:
        self.model = AutoModelForQuestionAnswering.from_pretrained(
            self.model_name_or_path,
            token=self.token.resolve_value() if self.token else None,
            **self.model_kwargs,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name_or_path,
            token=self.token.resolve_value() if self.token else None,
        )
    assert self.model is not None  # noqa: S101  # mypy doesn't know this is set in the branch above

    # hf_device_map appears to only be set when mixed devices are actually used.
    # When it is missing (or empty), fall back to the model's own `device` attribute,
    # which is set even for single-device models.
    device_map = getattr(self.model, "hf_device_map", None)
    if not device_map:
        self.device = ComponentDevice.from_single(Device.from_str(str(self.model.device)))
    else:
        self.device = ComponentDevice.from_multiple(device_map=DeviceMap.from_hf(device_map))
194
+
195
+ @staticmethod
196
+ def _flatten_documents(
197
+ queries: list[str], documents: list[list[Document]]
198
+ ) -> tuple[list[str], list[Document], list[int]]:
199
+ """
200
+ Flattens queries and Documents so all query-document pairs are arranged along one batch axis.
201
+ """
202
+ flattened_queries = [query for documents_, query in zip(documents, queries, strict=True) for _ in documents_]
203
+ flattened_documents = [document for documents_ in documents for document in documents_]
204
+ query_ids = [i for i, documents_ in enumerate(documents) for _ in documents_]
205
+ return flattened_queries, flattened_documents, query_ids
206
+
207
def _preprocess(
    self, *, queries: list[str], documents: list[Document], max_seq_length: int, query_ids: list[int], stride: int
) -> tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", list["Encoding"], list[int], list[int]]:
    """
    Splits and tokenizes Documents and preserves structures by returning mappings to query and Document IDs.

    Documents whose content is `None` are skipped with a warning. Their query and query id are skipped
    together with them, so that the query list and the document list handed to the tokenizer stay
    aligned pair by pair.

    :param queries:
        One query per entry of `documents`.
    :param documents:
        The Documents to tokenize.
    :param max_seq_length:
        Maximum number of tokens per sequence; longer inputs are split into several sequences.
    :param query_ids:
        For each entry of `documents`, the id of the query it belongs to.
    :param stride:
        Number of overlapping tokens between consecutive sequences of a split input.
    :returns:
        A tuple of `input_ids`, `attention_mask` and `sequence_ids` (with `None` entries replaced by -1),
        all moved to the first model device, followed by the raw encodings and, for every produced
        sequence, the id of its query and the index of its Document in `documents`.
    """
    kept_queries: list[str] = []
    kept_query_ids: list[int] = []
    document_ids: list[int] = []
    document_contents: list[str] = []
    for i, doc in enumerate(documents):
        if doc.content is None:
            logger.warning(
                "Document with id {doc_id} was passed to TransformersExtractiveReader. The Document doesn't "
                "contain any text and it will be ignored.",
                doc_id=doc.id,
            )
            continue
        # Keep the query-side lists in lockstep with the documents that survive the filter.
        kept_queries.append(queries[i])
        kept_query_ids.append(query_ids[i])
        document_ids.append(i)
        document_contents.append(doc.content)

    # mypy doesn't know the tokenizer is set in warm_up
    encodings_pt = self.tokenizer(
        kept_queries,
        document_contents,
        padding=True,
        truncation=True,
        max_length=max_seq_length,
        return_tensors="pt",
        return_overflowing_tokens=True,
        stride=stride,
    )

    # Take the first device used by `accelerate`. Needed to pass inputs from the tokenizer to the correct device.
    # mypy doesn't know this is set in warm_up
    first_device = self.device.first_device.to_torch()  # type: ignore[union-attr]

    input_ids = encodings_pt.input_ids.to(first_device)
    attention_mask = encodings_pt.attention_mask.to(first_device)

    # `overflow_to_sample_mapping` is used as the position of the (query, document) pair each produced
    # sequence came from, i.e. an index into the filtered lists built above.
    sample_mapping = encodings_pt.overflow_to_sample_mapping
    query_ids = [kept_query_ids[sample_id] for sample_id in sample_mapping]
    document_ids = [document_ids[sample_id] for sample_id in sample_mapping]

    encodings = encodings_pt.encodings
    # Positions without a sequence id are reported as None; encode them as -1 so a tensor can be built.
    sequence_ids = torch.tensor(
        [[id_ if id_ is not None else -1 for id_ in encoding.sequence_ids] for encoding in encodings]
    ).to(first_device)

    return input_ids, attention_mask, sequence_ids, encodings, query_ids, document_ids
256
+
257
+ def _postprocess(
258
+ self,
259
+ *,
260
+ start: "torch.Tensor",
261
+ end: "torch.Tensor",
262
+ sequence_ids: "torch.Tensor",
263
+ attention_mask: "torch.Tensor",
264
+ answers_per_seq: int,
265
+ encodings: list["Encoding"],
266
+ ) -> tuple[list[list[int]], list[list[int]], list["torch.Tensor"]]:
267
+ """
268
+ Turns start and end logits into probabilities for each answer span.
269
+
270
+ Unlike most other implementations, it doesn't normalize the scores in each split to make them easier to
271
+ compare across different splits. Returns the top k answer spans.
272
+ """
273
+ mask = sequence_ids == 1 # Only keep tokens from the context (should ignore special tokens)
274
+ mask = torch.logical_and(mask, attention_mask == 1) # Definitely remove special tokens
275
+ start = torch.where(mask, start, -torch.inf) # Apply the mask on the start logits
276
+ end = torch.where(mask, end, -torch.inf) # Apply the mask on the end logits
277
+ start = start.unsqueeze(-1)
278
+ end = end.unsqueeze(-2)
279
+
280
+ logits = start + end # shape: (batch_size, seq_length (start), seq_length (end))
281
+
282
+ # The mask here onwards is the same for all instances in the batch
283
+ # As such we do away with the batch dimension
284
+ mask = torch.ones(logits.shape[-2:], dtype=torch.bool, device=logits.device)
285
+ mask = torch.triu(mask) # End shouldn't be before start
286
+ masked_logits = torch.where(mask, logits, -torch.inf)
287
+ probabilities = torch.sigmoid(masked_logits * self.calibration_factor)
288
+
289
+ flat_probabilities = probabilities.flatten(-2, -1) # necessary for top-k
290
+
291
+ # top-k can return invalid candidates as well if answers_per_seq > num_valid_candidates
292
+ # We only keep probability > 0 candidates later on
293
+ candidates = torch.topk(flat_probabilities, answers_per_seq)
294
+ seq_length = logits.shape[-1]
295
+ start_candidates = candidates.indices // seq_length # Recover indices from flattening
296
+ end_candidates = candidates.indices % seq_length
297
+ candidates_values = candidates.values.cpu()
298
+ start_candidates = start_candidates.cpu()
299
+ end_candidates = end_candidates.cpu()
300
+
301
+ start_candidates_tokens_to_chars = []
302
+ end_candidates_tokens_to_chars = []
303
+ valid_candidates_values: list[torch.Tensor] = []
304
+ for i, (s_candidates, e_candidates, encoding) in enumerate(
305
+ zip(start_candidates, end_candidates, encodings, strict=True)
306
+ ):
307
+ # Those with probabilities > 0 are valid. topk may include masked candidates
308
+ # when answers_per_seq exceeds the number of valid spans, so filter all three lists together.
309
+ valid = candidates_values[i] > 0
310
+ s_char_spans = []
311
+ e_char_spans = []
312
+ for start_token, end_token in zip(s_candidates[valid], e_candidates[valid], strict=True):
313
+ # token_to_chars returns `None` for special tokens
314
+ # But we shouldn't have special tokens in the answers at this point
315
+ # The whole span is given by the start of the start_token (index 0)
316
+ # and the end of the end token (index 1)
317
+ s_char_spans.append(encoding.token_to_chars(start_token)[0])
318
+ e_char_spans.append(encoding.token_to_chars(end_token)[1])
319
+ start_candidates_tokens_to_chars.append(s_char_spans)
320
+ end_candidates_tokens_to_chars.append(e_char_spans)
321
+ valid_candidates_values.append(candidates_values[i][valid])
322
+
323
+ return start_candidates_tokens_to_chars, end_candidates_tokens_to_chars, valid_candidates_values
324
+
325
+ def _add_answer_page_number(self, answer: ExtractedAnswer) -> ExtractedAnswer:
326
+ if answer.meta is None:
327
+ answer = replace(answer, meta={})
328
+
329
+ if answer.document_offset is None:
330
+ return answer
331
+
332
+ if not answer.document or "page_number" not in answer.document.meta:
333
+ return answer
334
+
335
+ if not isinstance(answer.document.meta["page_number"], int):
336
+ logger.warning(
337
+ "Document's page_number must be int but is {type}. No page number will be added to the answer.",
338
+ type=type(answer.document.meta["page_number"]),
339
+ )
340
+ return answer
341
+
342
+ # Calculate the answer page number
343
+ if answer.document.content:
344
+ ans_start = answer.document_offset.start
345
+ answer_page_number = answer.document.meta["page_number"] + answer.document.content[:ans_start].count("\f")
346
+ answer.meta.update({"answer_page_number": answer_page_number})
347
+
348
+ return answer
349
+
350
def _nest_answers(
    self,
    *,
    start: list[list[int]],
    end: list[list[int]],
    probabilities: list["torch.Tensor"],
    flattened_documents: list[Document],
    queries: list[str],
    answers_per_seq: int,
    top_k: int | None,
    score_threshold: float | None,
    query_ids: list[int],
    document_ids: list[int],
    no_answer: bool,
    overlap_threshold: float | None,
) -> list[list[ExtractedAnswer]]:
    """
    Reconstructs the nested structure that existed before flattening.

    Also computes a no answer score. This score is different from most other implementations because it does not
    consider the no answer logit introduced with SQuAD 2. Instead, it just computes the probability that none of
    the returned top k answers is correct, as the product of `1 - score` over those answers.

    :param start:
        For every sequence, the character start offsets of its candidate answers.
    :param end:
        For every sequence, the character end offsets of its candidate answers.
    :param probabilities:
        For every sequence, the scores of its candidate answers.
    :param flattened_documents:
        The Documents that `document_ids` index into.
    :param queries:
        The queries, indexed by query id.
    :param answers_per_seq:
        Number of candidates requested per sequence; used to map an answer position back to its sequence.
    :param top_k:
        Maximum number of answers kept per query before the optional no-answer entry is appended.
    :param score_threshold:
        If not None, answers with a score below it are dropped (the no-answer entry included).
    :param query_ids:
        For every sequence, the id of the query it belongs to.
    :param document_ids:
        For every sequence, the index of its Document in `flattened_documents`.
    :param no_answer:
        Whether to append an answer without data whose score is the no answer score.
    :param overlap_threshold:
        Passed to `deduplicate_by_overlap`.
    :returns:
        One list of answers per query id, each sorted by descending score.
    """
    # Step 1: materialize every candidate span as an answer; the query is filled in during regrouping.
    answers_without_query = []
    for document_id, start_candidates_, end_candidates_, probabilities_ in zip(
        document_ids, start, end, probabilities, strict=True
    ):
        for start_, end_, probability in zip(start_candidates_, end_candidates_, probabilities_, strict=True):
            doc = flattened_documents[document_id]
            answers_without_query.append(
                ExtractedAnswer(
                    query="",  # Can't be None but we'll add it later
                    data=doc.content[start_:end_],  # type: ignore
                    document=doc,
                    score=probability.item(),
                    document_offset=ExtractedAnswer.Span(start_, end_),
                    meta={},
                )
            )
    # Step 2: walk the flat answer list once, handing consecutive answers to consecutive query ids.
    i = 0
    nested_answers = []
    for query_id in range(query_ids[-1] + 1):
        current_answers = []
        # `i // answers_per_seq` assumes every sequence contributes exactly `answers_per_seq`
        # answers. That's not guaranteed (see _postprocess: invalid candidates are
        # filtered out per sequence), but is fine here because `run` always passes a single
        # query, so every entry in `query_ids` is 0 and the index lookup is correct for any i.
        while i < len(answers_without_query) and query_ids[i // answers_per_seq] == query_id:
            current_answers.append(replace(answers_without_query[i], query=queries[query_id]))
            i += 1
        # Best answers first, so de-duplication and the top_k cut both favor higher scores.
        current_answers = sorted(current_answers, key=lambda ans: ans.score, reverse=True)
        current_answers = self.deduplicate_by_overlap(current_answers, overlap_threshold=overlap_threshold)
        current_answers = current_answers[:top_k]

        # Calculate the answer page number and add it to meta
        current_answers = [self._add_answer_page_number(answer=answer) for answer in current_answers]

        if no_answer:
            # Probability that every kept answer is wrong, treating the answer scores as independent.
            no_answer_score = math.prod(1 - answer.score for answer in current_answers)
            answer_ = ExtractedAnswer(
                data=None, query=queries[query_id], meta={}, document=None, score=no_answer_score
            )
            current_answers.append(answer_)
        # Re-sort so the no-answer entry lands at its place in the ranking.
        current_answers = sorted(current_answers, key=lambda ans: ans.score, reverse=True)
        if score_threshold is not None:
            current_answers = [answer for answer in current_answers if answer.score >= score_threshold]
        nested_answers.append(current_answers)

    return nested_answers
419
+
420
+ def _calculate_overlap(self, answer1_start: int, answer1_end: int, answer2_start: int, answer2_end: int) -> int:
421
+ """
422
+ Calculates the amount of overlap (in number of characters) between two answer offsets.
423
+
424
+ This Stack overflow
425
+ [post](https://stackoverflow.com/questions/325933/determine-whether-two-date-ranges-overlap/325964#325964)
426
+ explains how to calculate the overlap between two ranges.
427
+ """
428
+ # Check for overlap: (StartA <= EndB) and (StartB <= EndA)
429
+ if answer1_start <= answer2_end and answer2_start <= answer1_end:
430
+ return min(
431
+ answer1_end - answer1_start,
432
+ answer1_end - answer2_start,
433
+ answer2_end - answer1_start,
434
+ answer2_end - answer2_start,
435
+ )
436
+ return 0
437
+
438
+ def _should_keep(
439
+ self, candidate_answer: ExtractedAnswer, current_answers: list[ExtractedAnswer], overlap_threshold: float
440
+ ) -> bool:
441
+ """
442
+ Determines if the answer should be kept based on how much it overlaps with previous answers.
443
+
444
+ NOTE: We might want to avoid throwing away answers that only have a few character (or word) overlap:
445
+ - E.g. The answers "the river in" and "in Maine" from the context "I want to go to the river in Maine."
446
+ might both want to be kept.
447
+
448
+ :param candidate_answer:
449
+ Candidate answer that will be checked if it should be kept.
450
+ :param current_answers:
451
+ Current list of answers that will be kept.
452
+ :param overlap_threshold:
453
+ If the overlap between two answers is greater than this threshold then return False.
454
+ """
455
+ keep = True
456
+
457
+ # If the candidate answer doesn't have a document keep it
458
+ if not candidate_answer.document:
459
+ return keep
460
+
461
+ for ans in current_answers:
462
+ # If an answer in current_answers doesn't have a document skip the comparison
463
+ if not ans.document:
464
+ continue
465
+
466
+ # If offset is missing then keep both
467
+ if ans.document_offset is None:
468
+ continue
469
+
470
+ # If offset is missing then keep both
471
+ if candidate_answer.document_offset is None:
472
+ continue
473
+
474
+ # If the answers come from different documents then keep both
475
+ if candidate_answer.document.id != ans.document.id:
476
+ continue
477
+
478
+ overlap_len = self._calculate_overlap(
479
+ answer1_start=ans.document_offset.start,
480
+ answer1_end=ans.document_offset.end,
481
+ answer2_start=candidate_answer.document_offset.start,
482
+ answer2_end=candidate_answer.document_offset.end,
483
+ )
484
+
485
+ # If overlap is 0 then keep
486
+ if overlap_len == 0:
487
+ continue
488
+
489
+ overlap_frac_answer1 = overlap_len / (ans.document_offset.end - ans.document_offset.start)
490
+ overlap_frac_answer2 = overlap_len / (
491
+ candidate_answer.document_offset.end - candidate_answer.document_offset.start
492
+ )
493
+
494
+ if overlap_frac_answer1 > overlap_threshold or overlap_frac_answer2 > overlap_threshold:
495
+ keep = False
496
+ break
497
+
498
+ return keep
499
+
500
def deduplicate_by_overlap(
    self, answers: list[ExtractedAnswer], overlap_threshold: float | None
) -> list[ExtractedAnswer]:
    """
    De-duplicates overlapping Extractive Answers.

    De-duplicates overlapping Extractive Answers from the same document based on how much the spans of the
    answers overlap. Answers are visited in the given order and an answer is dropped when it overlaps too
    much with one that was already kept, so earlier answers win.

    :param answers:
        List of answers to be deduplicated. An empty list is returned unchanged.
    :param overlap_threshold:
        If set this will remove duplicate answers if they have an overlap larger than the
        supplied threshold. For example, for the answers "in the river in Maine" and "the river" we would remove
        one of these answers since the second answer has a 100% (1.0) overlap with the first answer.
        However, for the answers "the river in" and "in Maine" there is only a max overlap percentage of 25% so
        both of these answers could be kept if this variable is set to 0.24 or lower.
        If None is provided then all answers are kept.
    :returns:
        List of deduplicated answers.
    """
    # Nothing to do without a threshold, and nothing to compare when there are no answers at all
    # (which happens when no valid answer span was found).
    if overlap_threshold is None or not answers:
        return answers

    # The first answer has nothing to be compared against, so it is always kept.
    deduplicated_answers = [answers[0]]

    # Loop over remaining answers to check for overlaps with the ones kept so far
    for ans in answers[1:]:
        keep = self._should_keep(
            candidate_answer=ans, current_answers=deduplicated_answers, overlap_threshold=overlap_threshold
        )
        if keep:
            deduplicated_answers.append(ans)

    return deduplicated_answers
536
+
537
@component.output_types(answers=list[ExtractedAnswer])
def run(
    self,
    query: str,
    documents: list[Document],
    top_k: int | None = None,
    score_threshold: float | None = None,
    max_seq_length: int | None = None,
    stride: int | None = None,
    max_batch_size: int | None = None,
    answers_per_seq: int | None = None,
    no_answer: bool | None = None,
    overlap_threshold: float | None = None,
) -> dict[str, Any]:
    """
    Locates and extracts answers from the given Documents using the given query.

    Every optional argument overrides the value configured in `__init__` for this call only;
    passing `None` keeps the configured value. For `score_threshold`, `stride` and `overlap_threshold`
    an explicit `0` / `0.0` is honored as an override rather than being treated as "not set".

    :param query:
        Query string.
    :param documents:
        List of Documents in which you want to search for an answer to the query.
        An empty list yields an empty answer list.
    :param top_k:
        The maximum number of answers to return.
        An additional answer is returned if no_answer is set to True (default).
    :param score_threshold:
        Returns only answers with the score above this threshold.
    :param max_seq_length:
        Maximum number of tokens. If a sequence exceeds it, the sequence is split.
    :param stride:
        Number of tokens that overlap when sequence is split because it exceeds max_seq_length.
    :param max_batch_size:
        Maximum number of samples that are fed through the model at the same time.
    :param answers_per_seq:
        Number of answer candidates to consider per sequence.
        This is relevant when a Document was split into multiple sequences because of max_seq_length.
        Falls back to 20 when neither this argument nor the configured value is set.
    :param no_answer:
        Whether to return no answer scores.
    :param overlap_threshold:
        If set this will remove duplicate answers if they have an overlap larger than the
        supplied threshold. For example, for the answers "in the river in Maine" and "the river" we would remove
        one of these answers since the second answer has a 100% (1.0) overlap with the first answer.
        However, for the answers "the river in" and "in Maine" there is only a max overlap percentage of 25% so
        both of these answers could be kept if this variable is set to 0.24 or lower.
        If None is provided, the value configured in `__init__` is used.
    :returns:
        A dictionary with the key `answers` holding the list of answers sorted by (desc.) answer score.
    """
    if self.model is None:
        self.warm_up()

    if not documents:
        return {"answers": []}

    queries = [query]  # Temporary solution until we have decided what batching should look like in v2
    nested_documents = [documents]

    # Resolve per-call overrides. Zero is a meaningful value for the thresholds and the stride,
    # so those are compared against None instead of relying on truthiness.
    top_k = top_k or self.top_k
    score_threshold = score_threshold if score_threshold is not None else self.score_threshold
    max_seq_length = max_seq_length or self.max_seq_length
    stride = stride if stride is not None else self.stride
    max_batch_size = max_batch_size or self.max_batch_size
    answers_per_seq = answers_per_seq or self.answers_per_seq or 20
    no_answer = no_answer if no_answer is not None else self.no_answer
    overlap_threshold = overlap_threshold if overlap_threshold is not None else self.overlap_threshold

    flattened_queries, flattened_documents, query_ids = TransformersExtractiveReader._flatten_documents(
        queries, nested_documents
    )
    input_ids, attention_mask, sequence_ids, encodings, query_ids, document_ids = self._preprocess(
        queries=flattened_queries,
        documents=flattened_documents,
        max_seq_length=max_seq_length,
        query_ids=query_ids,
        stride=stride,
    )

    # Without a batch size limit everything goes through the model in a single batch.
    num_batches = math.ceil(input_ids.shape[0] / max_batch_size) if max_batch_size else 1
    batch_size = max_batch_size or input_ids.shape[0]

    start_logits_list = []
    end_logits_list = []

    for i in range(num_batches):
        start_index = i * batch_size
        end_index = start_index + batch_size
        cur_input_ids = input_ids[start_index:end_index]
        cur_attention_mask = attention_mask[start_index:end_index]

        with torch.inference_mode():
            # mypy doesn't know this is set in warm_up
            output = self.model(input_ids=cur_input_ids, attention_mask=cur_attention_mask)  # type: ignore[misc]
        cur_start_logits = output.start_logits
        cur_end_logits = output.end_logits
        if num_batches != 1:
            # With several batches the per-batch logits are collected on the CPU.
            cur_start_logits = cur_start_logits.cpu()
            cur_end_logits = cur_end_logits.cpu()
        start_logits_list.append(cur_start_logits)
        end_logits_list.append(cur_end_logits)

    start_logits = torch.cat(start_logits_list)
    end_logits = torch.cat(end_logits_list)

    start, end, probabilities = self._postprocess(
        start=start_logits,
        end=end_logits,
        sequence_ids=sequence_ids,
        attention_mask=attention_mask,
        answers_per_seq=answers_per_seq,
        encodings=encodings,
    )

    answers = self._nest_answers(
        start=start,
        end=end,
        probabilities=probabilities,
        flattened_documents=flattened_documents,
        queries=queries,
        answers_per_seq=answers_per_seq,
        top_k=top_k,
        score_threshold=score_threshold,
        query_ids=query_ids,
        document_ids=document_ids,
        no_answer=no_answer,
        overlap_threshold=overlap_threshold,
    )

    return {"answers": answers[0]}  # same temporary batching fix as above
File without changes
@@ -0,0 +1,7 @@
1
+ # SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
+ #
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ from .text_router import TransformersTextRouter
5
+ from .zero_shot_text_router import TransformersZeroShotTextRouter
6
+
7
+ __all__ = ["TransformersTextRouter", "TransformersZeroShotTextRouter"]