nucliadb 6.5.0.post4426__py3-none-any.whl → 6.5.0.post4484__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -47,6 +47,8 @@ from nucliadb_models.labels import translate_alias_to_system_label
47
47
  from nucliadb_models.metadata import Extra, Origin
48
48
  from nucliadb_models.search import (
49
49
  SCORE_TYPE,
50
+ AugmentedContext,
51
+ AugmentedTextBlock,
50
52
  ConversationalStrategy,
51
53
  FieldExtensionStrategy,
52
54
  FindParagraph,
@@ -66,8 +68,10 @@ from nucliadb_models.search import (
66
68
  RagStrategy,
67
69
  RagStrategyName,
68
70
  TableImageStrategy,
71
+ TextBlockAugmentationType,
69
72
  )
70
73
  from nucliadb_protos import resources_pb2
74
+ from nucliadb_protos.resources_pb2 import ExtractedText, FieldComputedMetadata
71
75
  from nucliadb_utils.asyncio_utils import run_concurrently
72
76
  from nucliadb_utils.utilities import get_storage
73
77
 
@@ -89,10 +93,7 @@ class ParagraphIdNotFoundInExtractedMetadata(Exception):
89
93
  class CappedPromptContext:
90
94
  """
91
95
  Class to keep track of the size (in number of characters) of the prompt context
92
- and raise an exception if it exceeds the configured limit.
93
-
94
- This class will automatically trim data that exceeds the limit when it's being
95
- set on the dictionary.
96
+ and automatically trim data that exceeds the limit when it's being set on the dictionary.
96
97
  """
97
98
 
98
99
  def __init__(self, max_size: Optional[int]):
@@ -246,6 +247,7 @@ async def full_resource_prompt_context(
246
247
  resource: Optional[str],
247
248
  strategy: FullResourceStrategy,
248
249
  metrics: Metrics,
250
+ augmented_context: AugmentedContext,
249
251
  ) -> None:
250
252
  """
251
253
  Algorithm steps:
@@ -298,6 +300,12 @@ async def full_resource_prompt_context(
298
300
  del context[tb_id]
299
301
  # Add the extracted text of each field to the context.
300
302
  context[field.full()] = extracted_text
303
+ augmented_context.fields[field.full()] = AugmentedTextBlock(
304
+ id=field.full(),
305
+ text=extracted_text,
306
+ augmentation_type=TextBlockAugmentationType.FULL_RESOURCE,
307
+ )
308
+
301
309
  added_fields.add(field.full())
302
310
 
303
311
  metrics.set("full_resource_ops", len(added_fields))
@@ -314,6 +322,7 @@ async def extend_prompt_context_with_metadata(
314
322
  kbid: str,
315
323
  strategy: MetadataExtensionStrategy,
316
324
  metrics: Metrics,
325
+ augmented_context: AugmentedContext,
317
326
  ) -> None:
318
327
  text_block_ids: list[TextBlockId] = []
319
328
  for text_block_id in context.text_block_ids():
@@ -329,19 +338,23 @@ async def extend_prompt_context_with_metadata(
329
338
  ops = 0
330
339
  if MetadataExtensionType.ORIGIN in strategy.types:
331
340
  ops += 1
332
- await extend_prompt_context_with_origin_metadata(context, kbid, text_block_ids)
341
+ await extend_prompt_context_with_origin_metadata(
342
+ context, kbid, text_block_ids, augmented_context
343
+ )
333
344
 
334
345
  if MetadataExtensionType.CLASSIFICATION_LABELS in strategy.types:
335
346
  ops += 1
336
- await extend_prompt_context_with_classification_labels(context, kbid, text_block_ids)
347
+ await extend_prompt_context_with_classification_labels(
348
+ context, kbid, text_block_ids, augmented_context
349
+ )
337
350
 
338
351
  if MetadataExtensionType.NERS in strategy.types:
339
352
  ops += 1
340
- await extend_prompt_context_with_ner(context, kbid, text_block_ids)
353
+ await extend_prompt_context_with_ner(context, kbid, text_block_ids, augmented_context)
341
354
 
342
355
  if MetadataExtensionType.EXTRA_METADATA in strategy.types:
343
356
  ops += 1
344
- await extend_prompt_context_with_extra_metadata(context, kbid, text_block_ids)
357
+ await extend_prompt_context_with_extra_metadata(context, kbid, text_block_ids, augmented_context)
345
358
 
346
359
  metrics.set("metadata_extension_ops", ops * len(text_block_ids))
347
360
 
@@ -356,7 +369,9 @@ def parse_text_block_id(text_block_id: str) -> TextBlockId:
356
369
  return FieldId.from_string(text_block_id)
357
370
 
358
371
 
359
- async def extend_prompt_context_with_origin_metadata(context, kbid, text_block_ids: list[TextBlockId]):
372
+ async def extend_prompt_context_with_origin_metadata(
373
+ context, kbid, text_block_ids: list[TextBlockId], augmented_context: AugmentedContext
374
+ ):
360
375
  async def _get_origin(kbid: str, rid: str) -> tuple[str, Optional[Origin]]:
361
376
  origin = None
362
377
  resource = await cache.get_resource(kbid, rid)
@@ -372,11 +387,19 @@ async def extend_prompt_context_with_origin_metadata(context, kbid, text_block_i
372
387
  for tb_id in text_block_ids:
373
388
  origin = rid_to_origin.get(tb_id.rid)
374
389
  if origin is not None and tb_id.full() in context.output:
375
- context[tb_id.full()] += f"\n\nDOCUMENT METADATA AT ORIGIN:\n{to_yaml(origin)}"
390
+ text = context.output.pop(tb_id.full())
391
+ extended_text = text + f"\n\nDOCUMENT METADATA AT ORIGIN:\n{to_yaml(origin)}"
392
+ context[tb_id.full()] = extended_text
393
+ augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
394
+ id=tb_id.full(),
395
+ text=extended_text,
396
+ parent=tb_id.full(),
397
+ augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
398
+ )
376
399
 
377
400
 
378
401
  async def extend_prompt_context_with_classification_labels(
379
- context, kbid, text_block_ids: list[TextBlockId]
402
+ context, kbid, text_block_ids: list[TextBlockId], augmented_context: AugmentedContext
380
403
  ):
381
404
  async def _get_labels(kbid: str, _id: TextBlockId) -> tuple[TextBlockId, list[tuple[str, str]]]:
382
405
  fid = _id if isinstance(_id, FieldId) else _id.field_id
@@ -402,13 +425,25 @@ async def extend_prompt_context_with_classification_labels(
402
425
  for tb_id in text_block_ids:
403
426
  labels = tb_id_to_labels.get(tb_id)
404
427
  if labels is not None and tb_id.full() in context.output:
428
+ text = context.output.pop(tb_id.full())
429
+
405
430
  labels_text = "DOCUMENT CLASSIFICATION LABELS:"
406
431
  for labelset, label in labels:
407
432
  labels_text += f"\n - {label} ({labelset})"
408
- context[tb_id.full()] += "\n\n" + labels_text
433
+ extended_text = text + "\n\n" + labels_text
434
+
435
+ context[tb_id.full()] = extended_text
436
+ augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
437
+ id=tb_id.full(),
438
+ text=extended_text,
439
+ parent=tb_id.full(),
440
+ augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
441
+ )
409
442
 
410
443
 
411
- async def extend_prompt_context_with_ner(context, kbid, text_block_ids: list[TextBlockId]):
444
+ async def extend_prompt_context_with_ner(
445
+ context, kbid, text_block_ids: list[TextBlockId], augmented_context: AugmentedContext
446
+ ):
412
447
  async def _get_ners(kbid: str, _id: TextBlockId) -> tuple[TextBlockId, dict[str, set[str]]]:
413
448
  fid = _id if isinstance(_id, FieldId) else _id.field_id
414
449
  ners: dict[str, set[str]] = {}
@@ -435,15 +470,28 @@ async def extend_prompt_context_with_ner(context, kbid, text_block_ids: list[Tex
435
470
  for tb_id in text_block_ids:
436
471
  ners = tb_id_to_ners.get(tb_id)
437
472
  if ners is not None and tb_id.full() in context.output:
473
+ text = context.output.pop(tb_id.full())
474
+
438
475
  ners_text = "DOCUMENT NAMED ENTITIES (NERs):"
439
476
  for family, tokens in ners.items():
440
477
  ners_text += f"\n - {family}:"
441
478
  for token in sorted(list(tokens)):
442
479
  ners_text += f"\n - {token}"
443
- context[tb_id.full()] += "\n\n" + ners_text
480
+
481
+ extended_text = text + "\n\n" + ners_text
482
+
483
+ context[tb_id.full()] = extended_text
484
+ augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
485
+ id=tb_id.full(),
486
+ text=extended_text,
487
+ parent=tb_id.full(),
488
+ augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
489
+ )
444
490
 
445
491
 
446
- async def extend_prompt_context_with_extra_metadata(context, kbid, text_block_ids: list[TextBlockId]):
492
+ async def extend_prompt_context_with_extra_metadata(
493
+ context, kbid, text_block_ids: list[TextBlockId], augmented_context: AugmentedContext
494
+ ):
447
495
  async def _get_extra(kbid: str, rid: str) -> tuple[str, Optional[Extra]]:
448
496
  extra = None
449
497
  resource = await cache.get_resource(kbid, rid)
@@ -459,7 +507,15 @@ async def extend_prompt_context_with_extra_metadata(context, kbid, text_block_id
459
507
  for tb_id in text_block_ids:
460
508
  extra = rid_to_extra.get(tb_id.rid)
461
509
  if extra is not None and tb_id.full() in context.output:
462
- context[tb_id.full()] += f"\n\nDOCUMENT EXTRA METADATA:\n{to_yaml(extra)}"
510
+ text = context.output.pop(tb_id.full())
511
+ extended_text = text + f"\n\nDOCUMENT EXTRA METADATA:\n{to_yaml(extra)}"
512
+ context[tb_id.full()] = extended_text
513
+ augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
514
+ id=tb_id.full(),
515
+ text=extended_text,
516
+ parent=tb_id.full(),
517
+ augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
518
+ )
463
519
 
464
520
 
465
521
  def to_yaml(obj: BaseModel) -> str:
@@ -477,6 +533,7 @@ async def field_extension_prompt_context(
477
533
  ordered_paragraphs: list[FindParagraph],
478
534
  strategy: FieldExtensionStrategy,
479
535
  metrics: Metrics,
536
+ augmented_context: AugmentedContext,
480
537
  ) -> None:
481
538
  """
482
539
  Algorithm steps:
@@ -518,115 +575,25 @@ async def field_extension_prompt_context(
518
575
  if tb_id.startswith(field.full()):
519
576
  del context[tb_id]
520
577
  # Add the extracted text of each field to the beginning of the context.
521
- context[field.full()] = extracted_text
578
+ if field.full() not in context.output:
579
+ context[field.full()] = extracted_text
580
+ augmented_context.fields[field.full()] = AugmentedTextBlock(
581
+ id=field.full(),
582
+ text=extracted_text,
583
+ augmentation_type=TextBlockAugmentationType.FIELD_EXTENSION,
584
+ )
522
585
 
523
586
  # Add the extracted text of each paragraph to the end of the context.
524
587
  for paragraph in ordered_paragraphs:
525
- context[paragraph.id] = _clean_paragraph_text(paragraph)
526
-
527
-
528
- async def get_paragraph_text_with_neighbours(
529
- kbid: str,
530
- pid: ParagraphId,
531
- field_paragraphs: list[ParagraphId],
532
- before: int = 0,
533
- after: int = 0,
534
- ) -> tuple[ParagraphId, str]:
535
- """
536
- This function will get the paragraph text of the paragraph with the neighbouring paragraphs included.
537
- Parameters:
538
- kbid: The knowledge box id.
539
- pid: The matching paragraph id.
540
- field_paragraphs: The list of paragraph ids of the field.
541
- before: The number of paragraphs to include before the matching paragraph.
542
- after: The number of paragraphs to include after the matching paragraph.
543
- """
544
-
545
- async def _get_paragraph_text(
546
- kbid: str,
547
- pid: ParagraphId,
548
- ) -> tuple[ParagraphId, str]:
549
- return pid, await get_paragraph_text(
550
- kbid=kbid,
551
- paragraph_id=pid,
552
- log_on_missing_field=True,
553
- )
554
-
555
- ops = []
556
- try:
557
- for paragraph_index in get_neighbouring_paragraph_indexes(
558
- field_paragraphs=field_paragraphs,
559
- matching_paragraph=pid,
560
- before=before,
561
- after=after,
562
- ):
563
- neighbour_pid = field_paragraphs[paragraph_index]
564
- ops.append(
565
- asyncio.create_task(
566
- _get_paragraph_text(
567
- kbid=kbid,
568
- pid=neighbour_pid,
569
- )
570
- )
571
- )
572
- except ParagraphIdNotFoundInExtractedMetadata:
573
- logger.warning(
574
- "Could not find matching paragraph in extracted metadata. This is odd and needs to be investigated.",
575
- extra={
576
- "kbid": kbid,
577
- "matching_paragraph": pid.full(),
578
- "field_paragraphs": [p.full() for p in field_paragraphs],
579
- },
580
- )
581
- # If we could not find the matching paragraph in the extracted metadata, we can't retrieve
582
- # the neighbouring paragraphs and we simply fetch the text of the matching paragraph.
583
- ops.append(
584
- asyncio.create_task(
585
- _get_paragraph_text(
586
- kbid=kbid,
587
- pid=pid,
588
- )
589
- )
590
- )
591
-
592
- results = []
593
- if len(ops) > 0:
594
- results = await asyncio.gather(*ops)
595
-
596
- # Sort the results by the paragraph start
597
- results.sort(key=lambda x: x[0].paragraph_start)
598
- paragraph_texts = []
599
- for _, text in results:
600
- if text != "":
601
- paragraph_texts.append(text)
602
- return pid, "\n\n".join(paragraph_texts)
588
+ if paragraph.id not in context.output:
589
+ context[paragraph.id] = _clean_paragraph_text(paragraph)
603
590
 
604
591
 
605
- async def get_field_paragraphs_list(
606
- kbid: str,
607
- field: FieldId,
608
- paragraphs: list[ParagraphId],
609
- ) -> None:
610
- """
611
- Modifies the paragraphs list by adding the paragraph ids of the field, sorted by position.
612
- """
613
- resource = await cache.get_resource(kbid, field.rid)
592
+ async def get_orm_field(kbid: str, field_id: FieldId) -> Optional[Field]:
593
+ resource = await cache.get_resource(kbid, field_id.rid)
614
594
  if resource is None: # pragma: no cover
615
- return
616
- field_obj: Field = await resource.get_field(key=field.key, type=field.pb_type, load=False)
617
- field_metadata: Optional[resources_pb2.FieldComputedMetadata] = await field_obj.get_field_metadata(
618
- force=True
619
- )
620
- if field_metadata is None: # pragma: no cover
621
- return
622
- for paragraph in field_metadata.metadata.paragraphs:
623
- paragraphs.append(
624
- ParagraphId(
625
- field_id=field,
626
- paragraph_start=paragraph.start,
627
- paragraph_end=paragraph.end,
628
- )
629
- )
595
+ return None
596
+ return await resource.get_field(key=field_id.key, type=field_id.pb_type, load=False)
630
597
 
631
598
 
632
599
  async def neighbouring_paragraphs_prompt_context(
@@ -635,61 +602,114 @@ async def neighbouring_paragraphs_prompt_context(
635
602
  ordered_text_blocks: list[FindParagraph],
636
603
  strategy: NeighbouringParagraphsStrategy,
637
604
  metrics: Metrics,
605
+ augmented_context: AugmentedContext,
638
606
  ) -> None:
639
607
  """
640
608
  This function will get the paragraph texts and then craft a context with the neighbouring paragraphs of the
641
- paragraphs in the ordered_paragraphs list. The number of paragraphs to include before and after each paragraph
609
+ paragraphs in the ordered_paragraphs list.
642
610
  """
643
- # First, get the sorted list of paragraphs for each matching field
644
- # so we can know the indexes of the neighbouring paragraphs
645
- unique_fields = {
646
- ParagraphId.from_string(text_block.id).field_id for text_block in ordered_text_blocks
611
+ retrieved_paragraphs_ids = [
612
+ ParagraphId.from_string(text_block.id) for text_block in ordered_text_blocks
613
+ ]
614
+ unique_field_ids = list({pid.field_id for pid in retrieved_paragraphs_ids})
615
+
616
+ # Get extracted texts and metadatas for all fields
617
+ fm_ops = []
618
+ et_ops = []
619
+ for field_id in unique_field_ids:
620
+ field = await get_orm_field(kbid, field_id)
621
+ if field is None:
622
+ continue
623
+ fm_ops.append(asyncio.create_task(field.get_field_metadata()))
624
+ et_ops.append(asyncio.create_task(field.get_extracted_text()))
625
+
626
+ field_metadatas: dict[FieldId, FieldComputedMetadata] = {
627
+ fid: fm for fid, fm in zip(unique_field_ids, await asyncio.gather(*fm_ops)) if fm is not None
647
628
  }
648
- paragraphs_by_field: dict[FieldId, list[ParagraphId]] = {}
649
- field_ops = []
650
- for field_id in unique_fields:
651
- plist = paragraphs_by_field.setdefault(field_id, [])
652
- field_ops.append(
653
- asyncio.create_task(get_field_paragraphs_list(kbid=kbid, field=field_id, paragraphs=plist))
654
- )
655
- if field_ops:
656
- await asyncio.gather(*field_ops)
657
-
658
- # Now, get the paragraph texts with the neighbouring paragraphs
659
- paragraph_ops = []
660
- for text_block in ordered_text_blocks:
661
- pid = ParagraphId.from_string(text_block.id)
662
- paragraph_ops.append(
663
- asyncio.create_task(
664
- get_paragraph_text_with_neighbours(
665
- kbid=kbid,
666
- pid=pid,
667
- before=strategy.before,
668
- after=strategy.after,
669
- field_paragraphs=paragraphs_by_field.get(pid.field_id, []),
670
- )
629
+ extracted_texts: dict[FieldId, ExtractedText] = {
630
+ fid: et for fid, et in zip(unique_field_ids, await asyncio.gather(*et_ops)) if et is not None
631
+ }
632
+
633
+ def _get_paragraph_text(extracted_text: ExtractedText, pid: ParagraphId) -> str:
634
+ if pid.field_id.subfield_id:
635
+ text = extracted_text.split_text.get(pid.field_id.subfield_id) or ""
636
+ else:
637
+ text = extracted_text.text
638
+ return text[pid.paragraph_start : pid.paragraph_end]
639
+
640
+ for pid in retrieved_paragraphs_ids:
641
+ # Add the retrieved paragraph first
642
+ field_extracted_text = extracted_texts.get(pid.field_id, None)
643
+ if field_extracted_text is None:
644
+ continue
645
+ ptext = _get_paragraph_text(field_extracted_text, pid)
646
+ if ptext:
647
+ context[pid.full()] = ptext
648
+
649
+ # Now add the neighbouring paragraphs
650
+ field_extracted_metadata = field_metadatas.get(pid.field_id, None)
651
+ if field_extracted_metadata is None:
652
+ continue
653
+
654
+ field_pids = [
655
+ ParagraphId(
656
+ field_id=pid.field_id,
657
+ paragraph_start=p.start,
658
+ paragraph_end=p.end,
671
659
  )
672
- )
673
- if not paragraph_ops: # pragma: no cover
674
- return
660
+ for p in field_extracted_metadata.metadata.paragraphs
661
+ ]
662
+ try:
663
+ index = field_pids.index(pid)
664
+ except IndexError:
665
+ continue
675
666
 
676
- results: list[tuple[ParagraphId, str]] = await asyncio.gather(*paragraph_ops)
667
+ for neighbour_index in get_neighbouring_indices(
668
+ index=index,
669
+ before=strategy.before,
670
+ after=strategy.after,
671
+ field_pids=field_pids,
672
+ ):
673
+ if neighbour_index == index:
674
+ # Already handled above
675
+ continue
676
+ try:
677
+ npid = field_pids[neighbour_index]
678
+ except IndexError:
679
+ continue
680
+ if npid in retrieved_paragraphs_ids or npid.full() in context.output:
681
+ # Already added above
682
+ continue
683
+ ptext = _get_paragraph_text(field_extracted_text, npid)
684
+ if not ptext:
685
+ continue
686
+ context[npid.full()] = ptext
687
+ augmented_context.paragraphs[npid.full()] = AugmentedTextBlock(
688
+ id=npid.full(),
689
+ text=ptext,
690
+ parent=pid.full(),
691
+ augmentation_type=TextBlockAugmentationType.NEIGHBOURING_PARAGRAPHS,
692
+ )
677
693
 
678
- metrics.set("neighbouring_paragraphs_ops", len(results))
694
+ metrics.set("neighbouring_paragraphs_ops", len(augmented_context.paragraphs))
679
695
 
680
- # Add the paragraph texts to the context
681
- for pid, text in results:
682
- if text != "":
683
- context[pid.full()] = text
696
+
697
+ def get_neighbouring_indices(
698
+ index: int, before: int, after: int, field_pids: list[ParagraphId]
699
+ ) -> list[int]:
700
+ lb_index = max(0, index - before)
701
+ ub_index = min(len(field_pids), index + after + 1)
702
+ return list(range(lb_index, index)) + list(range(index + 1, ub_index))
684
703
 
685
704
 
686
705
  async def conversation_prompt_context(
687
706
  context: CappedPromptContext,
688
707
  kbid: str,
689
708
  ordered_paragraphs: list[FindParagraph],
690
- conversational_strategy: ConversationalStrategy,
709
+ strategy: ConversationalStrategy,
691
710
  visual_llm: bool,
692
711
  metrics: Metrics,
712
+ augmented_context: AugmentedContext,
693
713
  ):
694
714
  analyzed_fields: List[str] = []
695
715
  ops = 0
@@ -721,7 +741,7 @@ async def conversation_prompt_context(
721
741
  cmetadata = await field_obj.get_metadata()
722
742
 
723
743
  attachments: List[resources_pb2.FieldRef] = []
724
- if conversational_strategy.full:
744
+ if strategy.full:
725
745
  ops += 5
726
746
  extracted_text = await field_obj.get_extracted_text()
727
747
  for current_page in range(1, cmetadata.pages + 1):
@@ -734,8 +754,16 @@ async def conversation_prompt_context(
734
754
  else:
735
755
  text = message.content.text.strip()
736
756
  pid = f"{rid}/{field_type}/{field_id}/{ident}/0-{len(text) + 1}"
737
- context[pid] = text
738
757
  attachments.extend(message.content.attachments_fields)
758
+ if pid in context.output:
759
+ continue
760
+ context[pid] = text
761
+ augmented_context.paragraphs[pid] = AugmentedTextBlock(
762
+ id=pid,
763
+ text=text,
764
+ parent=paragraph.id,
765
+ augmentation_type=TextBlockAugmentationType.CONVERSATION,
766
+ )
739
767
  else:
740
768
  # Add first message
741
769
  extracted_text = await field_obj.get_extracted_text()
@@ -747,13 +775,19 @@ async def conversation_prompt_context(
747
775
  text = extracted_text.split_text.get(ident, message.content.text.strip())
748
776
  else:
749
777
  text = message.content.text.strip()
778
+ attachments.extend(message.content.attachments_fields)
750
779
  pid = f"{rid}/{field_type}/{field_id}/{ident}/0-{len(text) + 1}"
780
+ if pid in context.output:
781
+ continue
751
782
  context[pid] = text
752
- attachments.extend(message.content.attachments_fields)
783
+ augmented_context.paragraphs[pid] = AugmentedTextBlock(
784
+ id=pid,
785
+ text=text,
786
+ parent=paragraph.id,
787
+ augmentation_type=TextBlockAugmentationType.CONVERSATION,
788
+ )
753
789
 
754
- messages: Deque[resources_pb2.Message] = deque(
755
- maxlen=conversational_strategy.max_messages
756
- )
790
+ messages: Deque[resources_pb2.Message] = deque(maxlen=strategy.max_messages)
757
791
 
758
792
  pending = -1
759
793
  for page in range(1, cmetadata.pages + 1):
@@ -764,7 +798,7 @@ async def conversation_prompt_context(
764
798
  if pending > 0:
765
799
  pending -= 1
766
800
  if message.ident == mident:
767
- pending = (conversational_strategy.max_messages - 1) // 2
801
+ pending = (strategy.max_messages - 1) // 2
768
802
  if pending == 0:
769
803
  break
770
804
  if pending == 0:
@@ -773,11 +807,19 @@ async def conversation_prompt_context(
773
807
  for message in messages:
774
808
  ops += 1
775
809
  text = message.content.text.strip()
810
+ attachments.extend(message.content.attachments_fields)
776
811
  pid = f"{rid}/{field_type}/{field_id}/{message.ident}/0-{len(message.content.text) + 1}"
812
+ if pid in context.output:
813
+ continue
777
814
  context[pid] = text
778
- attachments.extend(message.content.attachments_fields)
815
+ augmented_context.paragraphs[pid] = AugmentedTextBlock(
816
+ id=pid,
817
+ text=text,
818
+ parent=paragraph.id,
819
+ augmentation_type=TextBlockAugmentationType.CONVERSATION,
820
+ )
779
821
 
780
- if conversational_strategy.attachments_text:
822
+ if strategy.attachments_text:
781
823
  # add on the context the images if vlm enabled
782
824
  for attachment in attachments:
783
825
  ops += 1
@@ -787,9 +829,18 @@ async def conversation_prompt_context(
787
829
  extracted_text = await field.get_extracted_text()
788
830
  if extracted_text is not None:
789
831
  pid = f"{rid}/{field_type}/{attachment.field_id}/0-{len(extracted_text.text) + 1}"
790
- context[pid] = f"Attachment {attachment.field_id}: {extracted_text.text}\n\n"
791
-
792
- if conversational_strategy.attachments_images and visual_llm:
832
+ if pid in context.output:
833
+ continue
834
+ text = f"Attachment {attachment.field_id}: {extracted_text.text}\n\n"
835
+ context[pid] = text
836
+ augmented_context.paragraphs[pid] = AugmentedTextBlock(
837
+ id=pid,
838
+ text=text,
839
+ parent=paragraph.id,
840
+ augmentation_type=TextBlockAugmentationType.CONVERSATION,
841
+ )
842
+
843
+ if strategy.attachments_images and visual_llm:
793
844
  for attachment in attachments:
794
845
  ops += 1
795
846
  file_field: File = await resource.get_field(
@@ -810,6 +861,7 @@ async def hierarchy_prompt_context(
810
861
  ordered_paragraphs: list[FindParagraph],
811
862
  strategy: HierarchyResourceStrategy,
812
863
  metrics: Metrics,
864
+ augmented_context: AugmentedContext,
813
865
  ) -> None:
814
866
  """
815
867
  This function will get the paragraph texts (possibly with extra characters, if extra_characters > 0) and then
@@ -870,6 +922,7 @@ async def hierarchy_prompt_context(
870
922
  resources[rid].paragraphs.append((paragraph, extended_paragraph_text))
871
923
 
872
924
  metrics.set("hierarchy_ops", len(resources))
925
+ augmented_paragraphs = set()
873
926
 
874
927
  # Modify the first paragraph of each resource to include the title and summary of the resource, as well as the
875
928
  # extended paragraph text of all the paragraphs in the resource.
@@ -889,13 +942,20 @@ async def hierarchy_prompt_context(
889
942
  if first_paragraph is not None:
890
943
  # The first paragraph is the only one holding the hierarchy information
891
944
  first_paragraph.text = f"DOCUMENT: {title_text} \n SUMMARY: {summary_text} \n RESOURCE CONTENT: {text_with_hierarchy}"
945
+ augmented_paragraphs.add(first_paragraph.id)
892
946
 
893
947
  # Now that the paragraphs have been modified, we can add them to the context
894
948
  for paragraph in ordered_paragraphs_copy:
895
949
  if paragraph.text == "":
896
950
  # Skip paragraphs that were cleared in the hierarchy expansion
897
951
  continue
898
- context[paragraph.id] = _clean_paragraph_text(paragraph)
952
+ paragraph_text = _clean_paragraph_text(paragraph)
953
+ context[paragraph.id] = paragraph_text
954
+ if paragraph.id in augmented_paragraphs:
955
+ field_id = ParagraphId.from_string(paragraph.id).field_id.full()
956
+ augmented_context.fields[field_id] = AugmentedTextBlock(
957
+ id=field_id, text=paragraph_text, augmentation_type=TextBlockAugmentationType.HIERARCHY
958
+ )
899
959
  return
900
960
 
901
961
 
@@ -927,6 +987,7 @@ class PromptContextBuilder:
927
987
  self.max_context_characters = max_context_characters
928
988
  self.visual_llm = visual_llm
929
989
  self.metrics = metrics
990
+ self.augmented_context = AugmentedContext(paragraphs={}, fields={})
930
991
 
931
992
  def prepend_user_context(self, context: CappedPromptContext):
932
993
  # Chat extra context passed by the user is the most important, therefore
@@ -938,17 +999,17 @@ class PromptContextBuilder:
938
999
 
939
1000
  async def build(
940
1001
  self,
941
- ) -> tuple[PromptContext, PromptContextOrder, PromptContextImages]:
1002
+ ) -> tuple[PromptContext, PromptContextOrder, PromptContextImages, AugmentedContext]:
942
1003
  ccontext = CappedPromptContext(max_size=self.max_context_characters)
1004
+ print(".......................")
943
1005
  self.prepend_user_context(ccontext)
944
1006
  await self._build_context(ccontext)
945
1007
  if self.visual_llm:
946
1008
  await self._build_context_images(ccontext)
947
-
948
1009
  context = ccontext.output
949
1010
  context_images = ccontext.images
950
1011
  context_order = {text_block_id: order for order, text_block_id in enumerate(context.keys())}
951
- return context, context_order, context_images
1012
+ return context, context_order, context_images, self.augmented_context
952
1013
 
953
1014
  async def _build_context_images(self, context: CappedPromptContext) -> None:
954
1015
  ops = 0
@@ -1033,6 +1094,11 @@ class PromptContextBuilder:
1033
1094
  for paragraph in self.ordered_paragraphs:
1034
1095
  context[paragraph.id] = _clean_paragraph_text(paragraph)
1035
1096
 
1097
+ strategies_not_handled_here = [
1098
+ RagStrategyName.PREQUERIES,
1099
+ RagStrategyName.GRAPH,
1100
+ ]
1101
+
1036
1102
  full_resource: Optional[FullResourceStrategy] = None
1037
1103
  hierarchy: Optional[HierarchyResourceStrategy] = None
1038
1104
  neighbouring_paragraphs: Optional[NeighbouringParagraphsStrategy] = None
@@ -1056,9 +1122,7 @@ class PromptContextBuilder:
1056
1122
  neighbouring_paragraphs = cast(NeighbouringParagraphsStrategy, strategy)
1057
1123
  elif strategy.name == RagStrategyName.METADATA_EXTENSION:
1058
1124
  metadata_extension = cast(MetadataExtensionStrategy, strategy)
1059
- elif (
1060
- strategy.name != RagStrategyName.PREQUERIES and strategy.name != RagStrategyName.GRAPH
1061
- ): # pragma: no cover
1125
+ elif strategy.name not in strategies_not_handled_here: # pragma: no cover
1062
1126
  # Prequeries and graph are not handled here
1063
1127
  logger.warning(
1064
1128
  "Unknown rag strategy",
@@ -1074,16 +1138,26 @@ class PromptContextBuilder:
1074
1138
  self.resource,
1075
1139
  full_resource,
1076
1140
  self.metrics,
1141
+ self.augmented_context,
1077
1142
  )
1078
1143
  if metadata_extension:
1079
1144
  await extend_prompt_context_with_metadata(
1080
- context, self.kbid, metadata_extension, self.metrics
1145
+ context,
1146
+ self.kbid,
1147
+ metadata_extension,
1148
+ self.metrics,
1149
+ self.augmented_context,
1081
1150
  )
1082
1151
  return
1083
1152
 
1084
1153
  if hierarchy:
1085
1154
  await hierarchy_prompt_context(
1086
- context, self.kbid, self.ordered_paragraphs, hierarchy, self.metrics
1155
+ context,
1156
+ self.kbid,
1157
+ self.ordered_paragraphs,
1158
+ hierarchy,
1159
+ self.metrics,
1160
+ self.augmented_context,
1087
1161
  )
1088
1162
  if neighbouring_paragraphs:
1089
1163
  await neighbouring_paragraphs_prompt_context(
@@ -1092,6 +1166,7 @@ class PromptContextBuilder:
1092
1166
  self.ordered_paragraphs,
1093
1167
  neighbouring_paragraphs,
1094
1168
  self.metrics,
1169
+ self.augmented_context,
1095
1170
  )
1096
1171
  if field_extension:
1097
1172
  await field_extension_prompt_context(
@@ -1100,6 +1175,7 @@ class PromptContextBuilder:
1100
1175
  self.ordered_paragraphs,
1101
1176
  field_extension,
1102
1177
  self.metrics,
1178
+ self.augmented_context,
1103
1179
  )
1104
1180
  if conversational_strategy:
1105
1181
  await conversation_prompt_context(
@@ -1109,10 +1185,15 @@ class PromptContextBuilder:
1109
1185
  conversational_strategy,
1110
1186
  self.visual_llm,
1111
1187
  self.metrics,
1188
+ self.augmented_context,
1112
1189
  )
1113
1190
  if metadata_extension:
1114
1191
  await extend_prompt_context_with_metadata(
1115
- context, self.kbid, metadata_extension, self.metrics
1192
+ context,
1193
+ self.kbid,
1194
+ metadata_extension,
1195
+ self.metrics,
1196
+ self.augmented_context,
1116
1197
  )
1117
1198
 
1118
1199
 
@@ -1136,25 +1217,3 @@ def _clean_paragraph_text(paragraph: FindParagraph) -> str:
1136
1217
  # Do not send highlight marks on prompt context
1137
1218
  text = text.replace("<mark>", "").replace("</mark>", "")
1138
1219
  return text
1139
-
1140
-
1141
- def get_neighbouring_paragraph_indexes(
1142
- field_paragraphs: list[ParagraphId],
1143
- matching_paragraph: ParagraphId,
1144
- before: int,
1145
- after: int,
1146
- ) -> list[int]:
1147
- """
1148
- Returns the indexes of the neighbouring paragraphs to fetch (including the matching paragraph).
1149
- """
1150
- assert before >= 0
1151
- assert after >= 0
1152
- try:
1153
- matching_index = field_paragraphs.index(matching_paragraph)
1154
- except ValueError:
1155
- raise ParagraphIdNotFoundInExtractedMetadata(
1156
- f"Matching paragraph {matching_paragraph.full()} not found in extracted metadata"
1157
- )
1158
- start_index = max(0, matching_index - before)
1159
- end_index = min(len(field_paragraphs), matching_index + after + 1)
1160
- return list(range(start_index, end_index))