nucliadb 6.5.0.post4426__py3-none-any.whl → 6.5.0.post4484__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- migrations/0037_backfill_catalog_facets.py +74 -0
- migrations/pg/0001_bootstrap.py +1 -1
- migrations/pg/0008_catalog_facets.py +43 -0
- migrations/pg/0009_extract_facets_safety.py +26 -0
- nucliadb/ingest/fields/base.py +18 -8
- nucliadb/ingest/orm/processor/pgcatalog.py +26 -0
- nucliadb/ingest/orm/resource.py +9 -4
- nucliadb/search/api/v1/catalog.py +14 -8
- nucliadb/search/search/chat/ask.py +12 -1
- nucliadb/search/search/chat/prompt.py +260 -201
- nucliadb/search/search/pgcatalog.py +174 -63
- {nucliadb-6.5.0.post4426.dist-info → nucliadb-6.5.0.post4484.dist-info}/METADATA +8 -8
- {nucliadb-6.5.0.post4426.dist-info → nucliadb-6.5.0.post4484.dist-info}/RECORD +16 -13
- {nucliadb-6.5.0.post4426.dist-info → nucliadb-6.5.0.post4484.dist-info}/WHEEL +0 -0
- {nucliadb-6.5.0.post4426.dist-info → nucliadb-6.5.0.post4484.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.5.0.post4426.dist-info → nucliadb-6.5.0.post4484.dist-info}/top_level.txt +0 -0
@@ -47,6 +47,8 @@ from nucliadb_models.labels import translate_alias_to_system_label
|
|
47
47
|
from nucliadb_models.metadata import Extra, Origin
|
48
48
|
from nucliadb_models.search import (
|
49
49
|
SCORE_TYPE,
|
50
|
+
AugmentedContext,
|
51
|
+
AugmentedTextBlock,
|
50
52
|
ConversationalStrategy,
|
51
53
|
FieldExtensionStrategy,
|
52
54
|
FindParagraph,
|
@@ -66,8 +68,10 @@ from nucliadb_models.search import (
|
|
66
68
|
RagStrategy,
|
67
69
|
RagStrategyName,
|
68
70
|
TableImageStrategy,
|
71
|
+
TextBlockAugmentationType,
|
69
72
|
)
|
70
73
|
from nucliadb_protos import resources_pb2
|
74
|
+
from nucliadb_protos.resources_pb2 import ExtractedText, FieldComputedMetadata
|
71
75
|
from nucliadb_utils.asyncio_utils import run_concurrently
|
72
76
|
from nucliadb_utils.utilities import get_storage
|
73
77
|
|
@@ -89,10 +93,7 @@ class ParagraphIdNotFoundInExtractedMetadata(Exception):
|
|
89
93
|
class CappedPromptContext:
|
90
94
|
"""
|
91
95
|
Class to keep track of the size (in number of characters) of the prompt context
|
92
|
-
and
|
93
|
-
|
94
|
-
This class will automatically trim data that exceeds the limit when it's being
|
95
|
-
set on the dictionary.
|
96
|
+
and automatically trim data that exceeds the limit when it's being set on the dictionary.
|
96
97
|
"""
|
97
98
|
|
98
99
|
def __init__(self, max_size: Optional[int]):
|
@@ -246,6 +247,7 @@ async def full_resource_prompt_context(
|
|
246
247
|
resource: Optional[str],
|
247
248
|
strategy: FullResourceStrategy,
|
248
249
|
metrics: Metrics,
|
250
|
+
augmented_context: AugmentedContext,
|
249
251
|
) -> None:
|
250
252
|
"""
|
251
253
|
Algorithm steps:
|
@@ -298,6 +300,12 @@ async def full_resource_prompt_context(
|
|
298
300
|
del context[tb_id]
|
299
301
|
# Add the extracted text of each field to the context.
|
300
302
|
context[field.full()] = extracted_text
|
303
|
+
augmented_context.fields[field.full()] = AugmentedTextBlock(
|
304
|
+
id=field.full(),
|
305
|
+
text=extracted_text,
|
306
|
+
augmentation_type=TextBlockAugmentationType.FULL_RESOURCE,
|
307
|
+
)
|
308
|
+
|
301
309
|
added_fields.add(field.full())
|
302
310
|
|
303
311
|
metrics.set("full_resource_ops", len(added_fields))
|
@@ -314,6 +322,7 @@ async def extend_prompt_context_with_metadata(
|
|
314
322
|
kbid: str,
|
315
323
|
strategy: MetadataExtensionStrategy,
|
316
324
|
metrics: Metrics,
|
325
|
+
augmented_context: AugmentedContext,
|
317
326
|
) -> None:
|
318
327
|
text_block_ids: list[TextBlockId] = []
|
319
328
|
for text_block_id in context.text_block_ids():
|
@@ -329,19 +338,23 @@ async def extend_prompt_context_with_metadata(
|
|
329
338
|
ops = 0
|
330
339
|
if MetadataExtensionType.ORIGIN in strategy.types:
|
331
340
|
ops += 1
|
332
|
-
await extend_prompt_context_with_origin_metadata(
|
341
|
+
await extend_prompt_context_with_origin_metadata(
|
342
|
+
context, kbid, text_block_ids, augmented_context
|
343
|
+
)
|
333
344
|
|
334
345
|
if MetadataExtensionType.CLASSIFICATION_LABELS in strategy.types:
|
335
346
|
ops += 1
|
336
|
-
await extend_prompt_context_with_classification_labels(
|
347
|
+
await extend_prompt_context_with_classification_labels(
|
348
|
+
context, kbid, text_block_ids, augmented_context
|
349
|
+
)
|
337
350
|
|
338
351
|
if MetadataExtensionType.NERS in strategy.types:
|
339
352
|
ops += 1
|
340
|
-
await extend_prompt_context_with_ner(context, kbid, text_block_ids)
|
353
|
+
await extend_prompt_context_with_ner(context, kbid, text_block_ids, augmented_context)
|
341
354
|
|
342
355
|
if MetadataExtensionType.EXTRA_METADATA in strategy.types:
|
343
356
|
ops += 1
|
344
|
-
await extend_prompt_context_with_extra_metadata(context, kbid, text_block_ids)
|
357
|
+
await extend_prompt_context_with_extra_metadata(context, kbid, text_block_ids, augmented_context)
|
345
358
|
|
346
359
|
metrics.set("metadata_extension_ops", ops * len(text_block_ids))
|
347
360
|
|
@@ -356,7 +369,9 @@ def parse_text_block_id(text_block_id: str) -> TextBlockId:
|
|
356
369
|
return FieldId.from_string(text_block_id)
|
357
370
|
|
358
371
|
|
359
|
-
async def extend_prompt_context_with_origin_metadata(
|
372
|
+
async def extend_prompt_context_with_origin_metadata(
|
373
|
+
context, kbid, text_block_ids: list[TextBlockId], augmented_context: AugmentedContext
|
374
|
+
):
|
360
375
|
async def _get_origin(kbid: str, rid: str) -> tuple[str, Optional[Origin]]:
|
361
376
|
origin = None
|
362
377
|
resource = await cache.get_resource(kbid, rid)
|
@@ -372,11 +387,19 @@ async def extend_prompt_context_with_origin_metadata(context, kbid, text_block_i
|
|
372
387
|
for tb_id in text_block_ids:
|
373
388
|
origin = rid_to_origin.get(tb_id.rid)
|
374
389
|
if origin is not None and tb_id.full() in context.output:
|
375
|
-
context
|
390
|
+
text = context.output.pop(tb_id.full())
|
391
|
+
extended_text = text + f"\n\nDOCUMENT METADATA AT ORIGIN:\n{to_yaml(origin)}"
|
392
|
+
context[tb_id.full()] = extended_text
|
393
|
+
augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
|
394
|
+
id=tb_id.full(),
|
395
|
+
text=extended_text,
|
396
|
+
parent=tb_id.full(),
|
397
|
+
augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
|
398
|
+
)
|
376
399
|
|
377
400
|
|
378
401
|
async def extend_prompt_context_with_classification_labels(
|
379
|
-
context, kbid, text_block_ids: list[TextBlockId]
|
402
|
+
context, kbid, text_block_ids: list[TextBlockId], augmented_context: AugmentedContext
|
380
403
|
):
|
381
404
|
async def _get_labels(kbid: str, _id: TextBlockId) -> tuple[TextBlockId, list[tuple[str, str]]]:
|
382
405
|
fid = _id if isinstance(_id, FieldId) else _id.field_id
|
@@ -402,13 +425,25 @@ async def extend_prompt_context_with_classification_labels(
|
|
402
425
|
for tb_id in text_block_ids:
|
403
426
|
labels = tb_id_to_labels.get(tb_id)
|
404
427
|
if labels is not None and tb_id.full() in context.output:
|
428
|
+
text = context.output.pop(tb_id.full())
|
429
|
+
|
405
430
|
labels_text = "DOCUMENT CLASSIFICATION LABELS:"
|
406
431
|
for labelset, label in labels:
|
407
432
|
labels_text += f"\n - {label} ({labelset})"
|
408
|
-
|
433
|
+
extended_text = text + "\n\n" + labels_text
|
434
|
+
|
435
|
+
context[tb_id.full()] = extended_text
|
436
|
+
augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
|
437
|
+
id=tb_id.full(),
|
438
|
+
text=extended_text,
|
439
|
+
parent=tb_id.full(),
|
440
|
+
augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
|
441
|
+
)
|
409
442
|
|
410
443
|
|
411
|
-
async def extend_prompt_context_with_ner(
|
444
|
+
async def extend_prompt_context_with_ner(
|
445
|
+
context, kbid, text_block_ids: list[TextBlockId], augmented_context: AugmentedContext
|
446
|
+
):
|
412
447
|
async def _get_ners(kbid: str, _id: TextBlockId) -> tuple[TextBlockId, dict[str, set[str]]]:
|
413
448
|
fid = _id if isinstance(_id, FieldId) else _id.field_id
|
414
449
|
ners: dict[str, set[str]] = {}
|
@@ -435,15 +470,28 @@ async def extend_prompt_context_with_ner(context, kbid, text_block_ids: list[Tex
|
|
435
470
|
for tb_id in text_block_ids:
|
436
471
|
ners = tb_id_to_ners.get(tb_id)
|
437
472
|
if ners is not None and tb_id.full() in context.output:
|
473
|
+
text = context.output.pop(tb_id.full())
|
474
|
+
|
438
475
|
ners_text = "DOCUMENT NAMED ENTITIES (NERs):"
|
439
476
|
for family, tokens in ners.items():
|
440
477
|
ners_text += f"\n - {family}:"
|
441
478
|
for token in sorted(list(tokens)):
|
442
479
|
ners_text += f"\n - {token}"
|
443
|
-
|
480
|
+
|
481
|
+
extended_text = text + "\n\n" + ners_text
|
482
|
+
|
483
|
+
context[tb_id.full()] = extended_text
|
484
|
+
augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
|
485
|
+
id=tb_id.full(),
|
486
|
+
text=extended_text,
|
487
|
+
parent=tb_id.full(),
|
488
|
+
augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
|
489
|
+
)
|
444
490
|
|
445
491
|
|
446
|
-
async def extend_prompt_context_with_extra_metadata(
|
492
|
+
async def extend_prompt_context_with_extra_metadata(
|
493
|
+
context, kbid, text_block_ids: list[TextBlockId], augmented_context: AugmentedContext
|
494
|
+
):
|
447
495
|
async def _get_extra(kbid: str, rid: str) -> tuple[str, Optional[Extra]]:
|
448
496
|
extra = None
|
449
497
|
resource = await cache.get_resource(kbid, rid)
|
@@ -459,7 +507,15 @@ async def extend_prompt_context_with_extra_metadata(context, kbid, text_block_id
|
|
459
507
|
for tb_id in text_block_ids:
|
460
508
|
extra = rid_to_extra.get(tb_id.rid)
|
461
509
|
if extra is not None and tb_id.full() in context.output:
|
462
|
-
context
|
510
|
+
text = context.output.pop(tb_id.full())
|
511
|
+
extended_text = text + f"\n\nDOCUMENT EXTRA METADATA:\n{to_yaml(extra)}"
|
512
|
+
context[tb_id.full()] = extended_text
|
513
|
+
augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
|
514
|
+
id=tb_id.full(),
|
515
|
+
text=extended_text,
|
516
|
+
parent=tb_id.full(),
|
517
|
+
augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
|
518
|
+
)
|
463
519
|
|
464
520
|
|
465
521
|
def to_yaml(obj: BaseModel) -> str:
|
@@ -477,6 +533,7 @@ async def field_extension_prompt_context(
|
|
477
533
|
ordered_paragraphs: list[FindParagraph],
|
478
534
|
strategy: FieldExtensionStrategy,
|
479
535
|
metrics: Metrics,
|
536
|
+
augmented_context: AugmentedContext,
|
480
537
|
) -> None:
|
481
538
|
"""
|
482
539
|
Algorithm steps:
|
@@ -518,115 +575,25 @@ async def field_extension_prompt_context(
|
|
518
575
|
if tb_id.startswith(field.full()):
|
519
576
|
del context[tb_id]
|
520
577
|
# Add the extracted text of each field to the beginning of the context.
|
521
|
-
|
578
|
+
if field.full() not in context.output:
|
579
|
+
context[field.full()] = extracted_text
|
580
|
+
augmented_context.fields[field.full()] = AugmentedTextBlock(
|
581
|
+
id=field.full(),
|
582
|
+
text=extracted_text,
|
583
|
+
augmentation_type=TextBlockAugmentationType.FIELD_EXTENSION,
|
584
|
+
)
|
522
585
|
|
523
586
|
# Add the extracted text of each paragraph to the end of the context.
|
524
587
|
for paragraph in ordered_paragraphs:
|
525
|
-
|
526
|
-
|
527
|
-
|
528
|
-
async def get_paragraph_text_with_neighbours(
|
529
|
-
kbid: str,
|
530
|
-
pid: ParagraphId,
|
531
|
-
field_paragraphs: list[ParagraphId],
|
532
|
-
before: int = 0,
|
533
|
-
after: int = 0,
|
534
|
-
) -> tuple[ParagraphId, str]:
|
535
|
-
"""
|
536
|
-
This function will get the paragraph text of the paragraph with the neighbouring paragraphs included.
|
537
|
-
Parameters:
|
538
|
-
kbid: The knowledge box id.
|
539
|
-
pid: The matching paragraph id.
|
540
|
-
field_paragraphs: The list of paragraph ids of the field.
|
541
|
-
before: The number of paragraphs to include before the matching paragraph.
|
542
|
-
after: The number of paragraphs to include after the matching paragraph.
|
543
|
-
"""
|
544
|
-
|
545
|
-
async def _get_paragraph_text(
|
546
|
-
kbid: str,
|
547
|
-
pid: ParagraphId,
|
548
|
-
) -> tuple[ParagraphId, str]:
|
549
|
-
return pid, await get_paragraph_text(
|
550
|
-
kbid=kbid,
|
551
|
-
paragraph_id=pid,
|
552
|
-
log_on_missing_field=True,
|
553
|
-
)
|
554
|
-
|
555
|
-
ops = []
|
556
|
-
try:
|
557
|
-
for paragraph_index in get_neighbouring_paragraph_indexes(
|
558
|
-
field_paragraphs=field_paragraphs,
|
559
|
-
matching_paragraph=pid,
|
560
|
-
before=before,
|
561
|
-
after=after,
|
562
|
-
):
|
563
|
-
neighbour_pid = field_paragraphs[paragraph_index]
|
564
|
-
ops.append(
|
565
|
-
asyncio.create_task(
|
566
|
-
_get_paragraph_text(
|
567
|
-
kbid=kbid,
|
568
|
-
pid=neighbour_pid,
|
569
|
-
)
|
570
|
-
)
|
571
|
-
)
|
572
|
-
except ParagraphIdNotFoundInExtractedMetadata:
|
573
|
-
logger.warning(
|
574
|
-
"Could not find matching paragraph in extracted metadata. This is odd and needs to be investigated.",
|
575
|
-
extra={
|
576
|
-
"kbid": kbid,
|
577
|
-
"matching_paragraph": pid.full(),
|
578
|
-
"field_paragraphs": [p.full() for p in field_paragraphs],
|
579
|
-
},
|
580
|
-
)
|
581
|
-
# If we could not find the matching paragraph in the extracted metadata, we can't retrieve
|
582
|
-
# the neighbouring paragraphs and we simply fetch the text of the matching paragraph.
|
583
|
-
ops.append(
|
584
|
-
asyncio.create_task(
|
585
|
-
_get_paragraph_text(
|
586
|
-
kbid=kbid,
|
587
|
-
pid=pid,
|
588
|
-
)
|
589
|
-
)
|
590
|
-
)
|
591
|
-
|
592
|
-
results = []
|
593
|
-
if len(ops) > 0:
|
594
|
-
results = await asyncio.gather(*ops)
|
595
|
-
|
596
|
-
# Sort the results by the paragraph start
|
597
|
-
results.sort(key=lambda x: x[0].paragraph_start)
|
598
|
-
paragraph_texts = []
|
599
|
-
for _, text in results:
|
600
|
-
if text != "":
|
601
|
-
paragraph_texts.append(text)
|
602
|
-
return pid, "\n\n".join(paragraph_texts)
|
588
|
+
if paragraph.id not in context.output:
|
589
|
+
context[paragraph.id] = _clean_paragraph_text(paragraph)
|
603
590
|
|
604
591
|
|
605
|
-
async def
|
606
|
-
kbid
|
607
|
-
field: FieldId,
|
608
|
-
paragraphs: list[ParagraphId],
|
609
|
-
) -> None:
|
610
|
-
"""
|
611
|
-
Modifies the paragraphs list by adding the paragraph ids of the field, sorted by position.
|
612
|
-
"""
|
613
|
-
resource = await cache.get_resource(kbid, field.rid)
|
592
|
+
async def get_orm_field(kbid: str, field_id: FieldId) -> Optional[Field]:
|
593
|
+
resource = await cache.get_resource(kbid, field_id.rid)
|
614
594
|
if resource is None: # pragma: no cover
|
615
|
-
return
|
616
|
-
|
617
|
-
field_metadata: Optional[resources_pb2.FieldComputedMetadata] = await field_obj.get_field_metadata(
|
618
|
-
force=True
|
619
|
-
)
|
620
|
-
if field_metadata is None: # pragma: no cover
|
621
|
-
return
|
622
|
-
for paragraph in field_metadata.metadata.paragraphs:
|
623
|
-
paragraphs.append(
|
624
|
-
ParagraphId(
|
625
|
-
field_id=field,
|
626
|
-
paragraph_start=paragraph.start,
|
627
|
-
paragraph_end=paragraph.end,
|
628
|
-
)
|
629
|
-
)
|
595
|
+
return None
|
596
|
+
return await resource.get_field(key=field_id.key, type=field_id.pb_type, load=False)
|
630
597
|
|
631
598
|
|
632
599
|
async def neighbouring_paragraphs_prompt_context(
|
@@ -635,61 +602,114 @@ async def neighbouring_paragraphs_prompt_context(
|
|
635
602
|
ordered_text_blocks: list[FindParagraph],
|
636
603
|
strategy: NeighbouringParagraphsStrategy,
|
637
604
|
metrics: Metrics,
|
605
|
+
augmented_context: AugmentedContext,
|
638
606
|
) -> None:
|
639
607
|
"""
|
640
608
|
This function will get the paragraph texts and then craft a context with the neighbouring paragraphs of the
|
641
|
-
paragraphs in the ordered_paragraphs list.
|
609
|
+
paragraphs in the ordered_paragraphs list.
|
642
610
|
"""
|
643
|
-
|
644
|
-
|
645
|
-
|
646
|
-
|
611
|
+
retrieved_paragraphs_ids = [
|
612
|
+
ParagraphId.from_string(text_block.id) for text_block in ordered_text_blocks
|
613
|
+
]
|
614
|
+
unique_field_ids = list({pid.field_id for pid in retrieved_paragraphs_ids})
|
615
|
+
|
616
|
+
# Get extracted texts and metadatas for all fields
|
617
|
+
fm_ops = []
|
618
|
+
et_ops = []
|
619
|
+
for field_id in unique_field_ids:
|
620
|
+
field = await get_orm_field(kbid, field_id)
|
621
|
+
if field is None:
|
622
|
+
continue
|
623
|
+
fm_ops.append(asyncio.create_task(field.get_field_metadata()))
|
624
|
+
et_ops.append(asyncio.create_task(field.get_extracted_text()))
|
625
|
+
|
626
|
+
field_metadatas: dict[FieldId, FieldComputedMetadata] = {
|
627
|
+
fid: fm for fid, fm in zip(unique_field_ids, await asyncio.gather(*fm_ops)) if fm is not None
|
647
628
|
}
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
|
654
|
-
|
655
|
-
|
656
|
-
|
657
|
-
|
658
|
-
|
659
|
-
|
660
|
-
|
661
|
-
|
662
|
-
|
663
|
-
|
664
|
-
|
665
|
-
|
666
|
-
|
667
|
-
|
668
|
-
|
669
|
-
|
670
|
-
|
629
|
+
extracted_texts: dict[FieldId, ExtractedText] = {
|
630
|
+
fid: et for fid, et in zip(unique_field_ids, await asyncio.gather(*et_ops)) if et is not None
|
631
|
+
}
|
632
|
+
|
633
|
+
def _get_paragraph_text(extracted_text: ExtractedText, pid: ParagraphId) -> str:
|
634
|
+
if pid.field_id.subfield_id:
|
635
|
+
text = extracted_text.split_text.get(pid.field_id.subfield_id) or ""
|
636
|
+
else:
|
637
|
+
text = extracted_text.text
|
638
|
+
return text[pid.paragraph_start : pid.paragraph_end]
|
639
|
+
|
640
|
+
for pid in retrieved_paragraphs_ids:
|
641
|
+
# Add the retrieved paragraph first
|
642
|
+
field_extracted_text = extracted_texts.get(pid.field_id, None)
|
643
|
+
if field_extracted_text is None:
|
644
|
+
continue
|
645
|
+
ptext = _get_paragraph_text(field_extracted_text, pid)
|
646
|
+
if ptext:
|
647
|
+
context[pid.full()] = ptext
|
648
|
+
|
649
|
+
# Now add the neighbouring paragraphs
|
650
|
+
field_extracted_metadata = field_metadatas.get(pid.field_id, None)
|
651
|
+
if field_extracted_metadata is None:
|
652
|
+
continue
|
653
|
+
|
654
|
+
field_pids = [
|
655
|
+
ParagraphId(
|
656
|
+
field_id=pid.field_id,
|
657
|
+
paragraph_start=p.start,
|
658
|
+
paragraph_end=p.end,
|
671
659
|
)
|
672
|
-
|
673
|
-
|
674
|
-
|
660
|
+
for p in field_extracted_metadata.metadata.paragraphs
|
661
|
+
]
|
662
|
+
try:
|
663
|
+
index = field_pids.index(pid)
|
664
|
+
except IndexError:
|
665
|
+
continue
|
675
666
|
|
676
|
-
|
667
|
+
for neighbour_index in get_neighbouring_indices(
|
668
|
+
index=index,
|
669
|
+
before=strategy.before,
|
670
|
+
after=strategy.after,
|
671
|
+
field_pids=field_pids,
|
672
|
+
):
|
673
|
+
if neighbour_index == index:
|
674
|
+
# Already handled above
|
675
|
+
continue
|
676
|
+
try:
|
677
|
+
npid = field_pids[neighbour_index]
|
678
|
+
except IndexError:
|
679
|
+
continue
|
680
|
+
if npid in retrieved_paragraphs_ids or npid.full() in context.output:
|
681
|
+
# Already added above
|
682
|
+
continue
|
683
|
+
ptext = _get_paragraph_text(field_extracted_text, npid)
|
684
|
+
if not ptext:
|
685
|
+
continue
|
686
|
+
context[npid.full()] = ptext
|
687
|
+
augmented_context.paragraphs[npid.full()] = AugmentedTextBlock(
|
688
|
+
id=npid.full(),
|
689
|
+
text=ptext,
|
690
|
+
parent=pid.full(),
|
691
|
+
augmentation_type=TextBlockAugmentationType.NEIGHBOURING_PARAGRAPHS,
|
692
|
+
)
|
677
693
|
|
678
|
-
metrics.set("neighbouring_paragraphs_ops", len(
|
694
|
+
metrics.set("neighbouring_paragraphs_ops", len(augmented_context.paragraphs))
|
679
695
|
|
680
|
-
|
681
|
-
|
682
|
-
|
683
|
-
|
696
|
+
|
697
|
+
def get_neighbouring_indices(
|
698
|
+
index: int, before: int, after: int, field_pids: list[ParagraphId]
|
699
|
+
) -> list[int]:
|
700
|
+
lb_index = max(0, index - before)
|
701
|
+
ub_index = min(len(field_pids), index + after + 1)
|
702
|
+
return list(range(lb_index, index)) + list(range(index + 1, ub_index))
|
684
703
|
|
685
704
|
|
686
705
|
async def conversation_prompt_context(
|
687
706
|
context: CappedPromptContext,
|
688
707
|
kbid: str,
|
689
708
|
ordered_paragraphs: list[FindParagraph],
|
690
|
-
|
709
|
+
strategy: ConversationalStrategy,
|
691
710
|
visual_llm: bool,
|
692
711
|
metrics: Metrics,
|
712
|
+
augmented_context: AugmentedContext,
|
693
713
|
):
|
694
714
|
analyzed_fields: List[str] = []
|
695
715
|
ops = 0
|
@@ -721,7 +741,7 @@ async def conversation_prompt_context(
|
|
721
741
|
cmetadata = await field_obj.get_metadata()
|
722
742
|
|
723
743
|
attachments: List[resources_pb2.FieldRef] = []
|
724
|
-
if
|
744
|
+
if strategy.full:
|
725
745
|
ops += 5
|
726
746
|
extracted_text = await field_obj.get_extracted_text()
|
727
747
|
for current_page in range(1, cmetadata.pages + 1):
|
@@ -734,8 +754,16 @@ async def conversation_prompt_context(
|
|
734
754
|
else:
|
735
755
|
text = message.content.text.strip()
|
736
756
|
pid = f"{rid}/{field_type}/{field_id}/{ident}/0-{len(text) + 1}"
|
737
|
-
context[pid] = text
|
738
757
|
attachments.extend(message.content.attachments_fields)
|
758
|
+
if pid in context.output:
|
759
|
+
continue
|
760
|
+
context[pid] = text
|
761
|
+
augmented_context.paragraphs[pid] = AugmentedTextBlock(
|
762
|
+
id=pid,
|
763
|
+
text=text,
|
764
|
+
parent=paragraph.id,
|
765
|
+
augmentation_type=TextBlockAugmentationType.CONVERSATION,
|
766
|
+
)
|
739
767
|
else:
|
740
768
|
# Add first message
|
741
769
|
extracted_text = await field_obj.get_extracted_text()
|
@@ -747,13 +775,19 @@ async def conversation_prompt_context(
|
|
747
775
|
text = extracted_text.split_text.get(ident, message.content.text.strip())
|
748
776
|
else:
|
749
777
|
text = message.content.text.strip()
|
778
|
+
attachments.extend(message.content.attachments_fields)
|
750
779
|
pid = f"{rid}/{field_type}/{field_id}/{ident}/0-{len(text) + 1}"
|
780
|
+
if pid in context.output:
|
781
|
+
continue
|
751
782
|
context[pid] = text
|
752
|
-
|
783
|
+
augmented_context.paragraphs[pid] = AugmentedTextBlock(
|
784
|
+
id=pid,
|
785
|
+
text=text,
|
786
|
+
parent=paragraph.id,
|
787
|
+
augmentation_type=TextBlockAugmentationType.CONVERSATION,
|
788
|
+
)
|
753
789
|
|
754
|
-
messages: Deque[resources_pb2.Message] = deque(
|
755
|
-
maxlen=conversational_strategy.max_messages
|
756
|
-
)
|
790
|
+
messages: Deque[resources_pb2.Message] = deque(maxlen=strategy.max_messages)
|
757
791
|
|
758
792
|
pending = -1
|
759
793
|
for page in range(1, cmetadata.pages + 1):
|
@@ -764,7 +798,7 @@ async def conversation_prompt_context(
|
|
764
798
|
if pending > 0:
|
765
799
|
pending -= 1
|
766
800
|
if message.ident == mident:
|
767
|
-
pending = (
|
801
|
+
pending = (strategy.max_messages - 1) // 2
|
768
802
|
if pending == 0:
|
769
803
|
break
|
770
804
|
if pending == 0:
|
@@ -773,11 +807,19 @@ async def conversation_prompt_context(
|
|
773
807
|
for message in messages:
|
774
808
|
ops += 1
|
775
809
|
text = message.content.text.strip()
|
810
|
+
attachments.extend(message.content.attachments_fields)
|
776
811
|
pid = f"{rid}/{field_type}/{field_id}/{message.ident}/0-{len(message.content.text) + 1}"
|
812
|
+
if pid in context.output:
|
813
|
+
continue
|
777
814
|
context[pid] = text
|
778
|
-
|
815
|
+
augmented_context.paragraphs[pid] = AugmentedTextBlock(
|
816
|
+
id=pid,
|
817
|
+
text=text,
|
818
|
+
parent=paragraph.id,
|
819
|
+
augmentation_type=TextBlockAugmentationType.CONVERSATION,
|
820
|
+
)
|
779
821
|
|
780
|
-
if
|
822
|
+
if strategy.attachments_text:
|
781
823
|
# add on the context the images if vlm enabled
|
782
824
|
for attachment in attachments:
|
783
825
|
ops += 1
|
@@ -787,9 +829,18 @@ async def conversation_prompt_context(
|
|
787
829
|
extracted_text = await field.get_extracted_text()
|
788
830
|
if extracted_text is not None:
|
789
831
|
pid = f"{rid}/{field_type}/{attachment.field_id}/0-{len(extracted_text.text) + 1}"
|
790
|
-
|
791
|
-
|
792
|
-
|
832
|
+
if pid in context.output:
|
833
|
+
continue
|
834
|
+
text = f"Attachment {attachment.field_id}: {extracted_text.text}\n\n"
|
835
|
+
context[pid] = text
|
836
|
+
augmented_context.paragraphs[pid] = AugmentedTextBlock(
|
837
|
+
id=pid,
|
838
|
+
text=text,
|
839
|
+
parent=paragraph.id,
|
840
|
+
augmentation_type=TextBlockAugmentationType.CONVERSATION,
|
841
|
+
)
|
842
|
+
|
843
|
+
if strategy.attachments_images and visual_llm:
|
793
844
|
for attachment in attachments:
|
794
845
|
ops += 1
|
795
846
|
file_field: File = await resource.get_field(
|
@@ -810,6 +861,7 @@ async def hierarchy_prompt_context(
|
|
810
861
|
ordered_paragraphs: list[FindParagraph],
|
811
862
|
strategy: HierarchyResourceStrategy,
|
812
863
|
metrics: Metrics,
|
864
|
+
augmented_context: AugmentedContext,
|
813
865
|
) -> None:
|
814
866
|
"""
|
815
867
|
This function will get the paragraph texts (possibly with extra characters, if extra_characters > 0) and then
|
@@ -870,6 +922,7 @@ async def hierarchy_prompt_context(
|
|
870
922
|
resources[rid].paragraphs.append((paragraph, extended_paragraph_text))
|
871
923
|
|
872
924
|
metrics.set("hierarchy_ops", len(resources))
|
925
|
+
augmented_paragraphs = set()
|
873
926
|
|
874
927
|
# Modify the first paragraph of each resource to include the title and summary of the resource, as well as the
|
875
928
|
# extended paragraph text of all the paragraphs in the resource.
|
@@ -889,13 +942,20 @@ async def hierarchy_prompt_context(
|
|
889
942
|
if first_paragraph is not None:
|
890
943
|
# The first paragraph is the only one holding the hierarchy information
|
891
944
|
first_paragraph.text = f"DOCUMENT: {title_text} \n SUMMARY: {summary_text} \n RESOURCE CONTENT: {text_with_hierarchy}"
|
945
|
+
augmented_paragraphs.add(first_paragraph.id)
|
892
946
|
|
893
947
|
# Now that the paragraphs have been modified, we can add them to the context
|
894
948
|
for paragraph in ordered_paragraphs_copy:
|
895
949
|
if paragraph.text == "":
|
896
950
|
# Skip paragraphs that were cleared in the hierarchy expansion
|
897
951
|
continue
|
898
|
-
|
952
|
+
paragraph_text = _clean_paragraph_text(paragraph)
|
953
|
+
context[paragraph.id] = paragraph_text
|
954
|
+
if paragraph.id in augmented_paragraphs:
|
955
|
+
field_id = ParagraphId.from_string(paragraph.id).field_id.full()
|
956
|
+
augmented_context.fields[field_id] = AugmentedTextBlock(
|
957
|
+
id=field_id, text=paragraph_text, augmentation_type=TextBlockAugmentationType.HIERARCHY
|
958
|
+
)
|
899
959
|
return
|
900
960
|
|
901
961
|
|
@@ -927,6 +987,7 @@ class PromptContextBuilder:
|
|
927
987
|
self.max_context_characters = max_context_characters
|
928
988
|
self.visual_llm = visual_llm
|
929
989
|
self.metrics = metrics
|
990
|
+
self.augmented_context = AugmentedContext(paragraphs={}, fields={})
|
930
991
|
|
931
992
|
def prepend_user_context(self, context: CappedPromptContext):
|
932
993
|
# Chat extra context passed by the user is the most important, therefore
|
@@ -938,17 +999,17 @@ class PromptContextBuilder:
|
|
938
999
|
|
939
1000
|
async def build(
|
940
1001
|
self,
|
941
|
-
) -> tuple[PromptContext, PromptContextOrder, PromptContextImages]:
|
1002
|
+
) -> tuple[PromptContext, PromptContextOrder, PromptContextImages, AugmentedContext]:
|
942
1003
|
ccontext = CappedPromptContext(max_size=self.max_context_characters)
|
1004
|
+
print(".......................")
|
943
1005
|
self.prepend_user_context(ccontext)
|
944
1006
|
await self._build_context(ccontext)
|
945
1007
|
if self.visual_llm:
|
946
1008
|
await self._build_context_images(ccontext)
|
947
|
-
|
948
1009
|
context = ccontext.output
|
949
1010
|
context_images = ccontext.images
|
950
1011
|
context_order = {text_block_id: order for order, text_block_id in enumerate(context.keys())}
|
951
|
-
return context, context_order, context_images
|
1012
|
+
return context, context_order, context_images, self.augmented_context
|
952
1013
|
|
953
1014
|
async def _build_context_images(self, context: CappedPromptContext) -> None:
|
954
1015
|
ops = 0
|
@@ -1033,6 +1094,11 @@ class PromptContextBuilder:
|
|
1033
1094
|
for paragraph in self.ordered_paragraphs:
|
1034
1095
|
context[paragraph.id] = _clean_paragraph_text(paragraph)
|
1035
1096
|
|
1097
|
+
strategies_not_handled_here = [
|
1098
|
+
RagStrategyName.PREQUERIES,
|
1099
|
+
RagStrategyName.GRAPH,
|
1100
|
+
]
|
1101
|
+
|
1036
1102
|
full_resource: Optional[FullResourceStrategy] = None
|
1037
1103
|
hierarchy: Optional[HierarchyResourceStrategy] = None
|
1038
1104
|
neighbouring_paragraphs: Optional[NeighbouringParagraphsStrategy] = None
|
@@ -1056,9 +1122,7 @@ class PromptContextBuilder:
|
|
1056
1122
|
neighbouring_paragraphs = cast(NeighbouringParagraphsStrategy, strategy)
|
1057
1123
|
elif strategy.name == RagStrategyName.METADATA_EXTENSION:
|
1058
1124
|
metadata_extension = cast(MetadataExtensionStrategy, strategy)
|
1059
|
-
elif
|
1060
|
-
strategy.name != RagStrategyName.PREQUERIES and strategy.name != RagStrategyName.GRAPH
|
1061
|
-
): # pragma: no cover
|
1125
|
+
elif strategy.name not in strategies_not_handled_here: # pragma: no cover
|
1062
1126
|
# Prequeries and graph are not handled here
|
1063
1127
|
logger.warning(
|
1064
1128
|
"Unknown rag strategy",
|
@@ -1074,16 +1138,26 @@ class PromptContextBuilder:
|
|
1074
1138
|
self.resource,
|
1075
1139
|
full_resource,
|
1076
1140
|
self.metrics,
|
1141
|
+
self.augmented_context,
|
1077
1142
|
)
|
1078
1143
|
if metadata_extension:
|
1079
1144
|
await extend_prompt_context_with_metadata(
|
1080
|
-
context,
|
1145
|
+
context,
|
1146
|
+
self.kbid,
|
1147
|
+
metadata_extension,
|
1148
|
+
self.metrics,
|
1149
|
+
self.augmented_context,
|
1081
1150
|
)
|
1082
1151
|
return
|
1083
1152
|
|
1084
1153
|
if hierarchy:
|
1085
1154
|
await hierarchy_prompt_context(
|
1086
|
-
context,
|
1155
|
+
context,
|
1156
|
+
self.kbid,
|
1157
|
+
self.ordered_paragraphs,
|
1158
|
+
hierarchy,
|
1159
|
+
self.metrics,
|
1160
|
+
self.augmented_context,
|
1087
1161
|
)
|
1088
1162
|
if neighbouring_paragraphs:
|
1089
1163
|
await neighbouring_paragraphs_prompt_context(
|
@@ -1092,6 +1166,7 @@ class PromptContextBuilder:
|
|
1092
1166
|
self.ordered_paragraphs,
|
1093
1167
|
neighbouring_paragraphs,
|
1094
1168
|
self.metrics,
|
1169
|
+
self.augmented_context,
|
1095
1170
|
)
|
1096
1171
|
if field_extension:
|
1097
1172
|
await field_extension_prompt_context(
|
@@ -1100,6 +1175,7 @@ class PromptContextBuilder:
|
|
1100
1175
|
self.ordered_paragraphs,
|
1101
1176
|
field_extension,
|
1102
1177
|
self.metrics,
|
1178
|
+
self.augmented_context,
|
1103
1179
|
)
|
1104
1180
|
if conversational_strategy:
|
1105
1181
|
await conversation_prompt_context(
|
@@ -1109,10 +1185,15 @@ class PromptContextBuilder:
|
|
1109
1185
|
conversational_strategy,
|
1110
1186
|
self.visual_llm,
|
1111
1187
|
self.metrics,
|
1188
|
+
self.augmented_context,
|
1112
1189
|
)
|
1113
1190
|
if metadata_extension:
|
1114
1191
|
await extend_prompt_context_with_metadata(
|
1115
|
-
context,
|
1192
|
+
context,
|
1193
|
+
self.kbid,
|
1194
|
+
metadata_extension,
|
1195
|
+
self.metrics,
|
1196
|
+
self.augmented_context,
|
1116
1197
|
)
|
1117
1198
|
|
1118
1199
|
|
@@ -1136,25 +1217,3 @@ def _clean_paragraph_text(paragraph: FindParagraph) -> str:
|
|
1136
1217
|
# Do not send highlight marks on prompt context
|
1137
1218
|
text = text.replace("<mark>", "").replace("</mark>", "")
|
1138
1219
|
return text
|
1139
|
-
|
1140
|
-
|
1141
|
-
def get_neighbouring_paragraph_indexes(
|
1142
|
-
field_paragraphs: list[ParagraphId],
|
1143
|
-
matching_paragraph: ParagraphId,
|
1144
|
-
before: int,
|
1145
|
-
after: int,
|
1146
|
-
) -> list[int]:
|
1147
|
-
"""
|
1148
|
-
Returns the indexes of the neighbouring paragraphs to fetch (including the matching paragraph).
|
1149
|
-
"""
|
1150
|
-
assert before >= 0
|
1151
|
-
assert after >= 0
|
1152
|
-
try:
|
1153
|
-
matching_index = field_paragraphs.index(matching_paragraph)
|
1154
|
-
except ValueError:
|
1155
|
-
raise ParagraphIdNotFoundInExtractedMetadata(
|
1156
|
-
f"Matching paragraph {matching_paragraph.full()} not found in extracted metadata"
|
1157
|
-
)
|
1158
|
-
start_index = max(0, matching_index - before)
|
1159
|
-
end_index = min(len(field_paragraphs), matching_index + after + 1)
|
1160
|
-
return list(range(start_index, end_index))
|