nucliadb 6.9.1.post5192__py3-none-any.whl → 6.10.0.post5705__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- migrations/0023_backfill_pg_catalog.py +2 -2
- migrations/0029_backfill_field_status.py +3 -4
- migrations/0032_remove_old_relations.py +2 -3
- migrations/0038_backfill_catalog_field_labels.py +2 -2
- migrations/0039_backfill_converation_splits_metadata.py +2 -2
- migrations/0041_reindex_conversations.py +137 -0
- migrations/pg/0010_shards_index.py +34 -0
- nucliadb/search/api/v1/resource/utils.py → migrations/pg/0011_catalog_statistics.py +5 -6
- migrations/pg/0012_catalog_statistics_undo.py +26 -0
- nucliadb/backups/create.py +2 -15
- nucliadb/backups/restore.py +4 -15
- nucliadb/backups/tasks.py +4 -1
- nucliadb/common/back_pressure/cache.py +2 -3
- nucliadb/common/back_pressure/materializer.py +7 -13
- nucliadb/common/back_pressure/settings.py +6 -6
- nucliadb/common/back_pressure/utils.py +1 -0
- nucliadb/common/cache.py +9 -9
- nucliadb/common/catalog/interface.py +12 -12
- nucliadb/common/catalog/pg.py +41 -29
- nucliadb/common/catalog/utils.py +3 -3
- nucliadb/common/cluster/manager.py +5 -4
- nucliadb/common/cluster/rebalance.py +483 -114
- nucliadb/common/cluster/rollover.py +25 -9
- nucliadb/common/cluster/settings.py +3 -8
- nucliadb/common/cluster/utils.py +34 -8
- nucliadb/common/context/__init__.py +7 -8
- nucliadb/common/context/fastapi.py +1 -2
- nucliadb/common/datamanagers/__init__.py +2 -4
- nucliadb/common/datamanagers/atomic.py +4 -2
- nucliadb/common/datamanagers/cluster.py +1 -2
- nucliadb/common/datamanagers/fields.py +3 -4
- nucliadb/common/datamanagers/kb.py +6 -6
- nucliadb/common/datamanagers/labels.py +2 -3
- nucliadb/common/datamanagers/resources.py +10 -33
- nucliadb/common/datamanagers/rollover.py +5 -7
- nucliadb/common/datamanagers/search_configurations.py +1 -2
- nucliadb/common/datamanagers/synonyms.py +1 -2
- nucliadb/common/datamanagers/utils.py +4 -4
- nucliadb/common/datamanagers/vectorsets.py +4 -4
- nucliadb/common/external_index_providers/base.py +32 -5
- nucliadb/common/external_index_providers/manager.py +4 -5
- nucliadb/common/filter_expression.py +128 -40
- nucliadb/common/http_clients/processing.py +12 -23
- nucliadb/common/ids.py +6 -4
- nucliadb/common/locking.py +1 -2
- nucliadb/common/maindb/driver.py +9 -8
- nucliadb/common/maindb/local.py +5 -5
- nucliadb/common/maindb/pg.py +9 -8
- nucliadb/common/nidx.py +3 -4
- nucliadb/export_import/datamanager.py +4 -3
- nucliadb/export_import/exporter.py +11 -19
- nucliadb/export_import/importer.py +13 -6
- nucliadb/export_import/tasks.py +2 -0
- nucliadb/export_import/utils.py +6 -18
- nucliadb/health.py +2 -2
- nucliadb/ingest/app.py +8 -8
- nucliadb/ingest/consumer/consumer.py +8 -10
- nucliadb/ingest/consumer/pull.py +3 -8
- nucliadb/ingest/consumer/service.py +3 -3
- nucliadb/ingest/consumer/utils.py +1 -1
- nucliadb/ingest/fields/base.py +28 -49
- nucliadb/ingest/fields/conversation.py +12 -12
- nucliadb/ingest/fields/exceptions.py +1 -2
- nucliadb/ingest/fields/file.py +22 -8
- nucliadb/ingest/fields/link.py +7 -7
- nucliadb/ingest/fields/text.py +2 -3
- nucliadb/ingest/orm/brain_v2.py +78 -64
- nucliadb/ingest/orm/broker_message.py +2 -4
- nucliadb/ingest/orm/entities.py +10 -209
- nucliadb/ingest/orm/index_message.py +4 -4
- nucliadb/ingest/orm/knowledgebox.py +18 -27
- nucliadb/ingest/orm/processor/auditing.py +1 -3
- nucliadb/ingest/orm/processor/data_augmentation.py +1 -2
- nucliadb/ingest/orm/processor/processor.py +27 -27
- nucliadb/ingest/orm/processor/sequence_manager.py +1 -2
- nucliadb/ingest/orm/resource.py +72 -70
- nucliadb/ingest/orm/utils.py +1 -1
- nucliadb/ingest/processing.py +17 -17
- nucliadb/ingest/serialize.py +202 -145
- nucliadb/ingest/service/writer.py +3 -109
- nucliadb/ingest/settings.py +3 -4
- nucliadb/ingest/utils.py +1 -2
- nucliadb/learning_proxy.py +11 -11
- nucliadb/metrics_exporter.py +5 -4
- nucliadb/middleware/__init__.py +82 -1
- nucliadb/migrator/datamanager.py +3 -4
- nucliadb/migrator/migrator.py +1 -2
- nucliadb/migrator/models.py +1 -2
- nucliadb/migrator/settings.py +1 -2
- nucliadb/models/internal/augment.py +614 -0
- nucliadb/models/internal/processing.py +19 -19
- nucliadb/openapi.py +2 -2
- nucliadb/purge/__init__.py +3 -8
- nucliadb/purge/orphan_shards.py +1 -2
- nucliadb/reader/__init__.py +5 -0
- nucliadb/reader/api/models.py +6 -13
- nucliadb/reader/api/v1/download.py +59 -38
- nucliadb/reader/api/v1/export_import.py +4 -4
- nucliadb/reader/api/v1/learning_config.py +24 -4
- nucliadb/reader/api/v1/resource.py +61 -9
- nucliadb/reader/api/v1/services.py +18 -14
- nucliadb/reader/app.py +3 -1
- nucliadb/reader/reader/notifications.py +1 -2
- nucliadb/search/api/v1/__init__.py +2 -0
- nucliadb/search/api/v1/ask.py +3 -4
- nucliadb/search/api/v1/augment.py +585 -0
- nucliadb/search/api/v1/catalog.py +11 -15
- nucliadb/search/api/v1/find.py +16 -22
- nucliadb/search/api/v1/hydrate.py +25 -25
- nucliadb/search/api/v1/knowledgebox.py +1 -2
- nucliadb/search/api/v1/predict_proxy.py +1 -2
- nucliadb/search/api/v1/resource/ask.py +7 -7
- nucliadb/search/api/v1/resource/ingestion_agents.py +5 -6
- nucliadb/search/api/v1/resource/search.py +9 -11
- nucliadb/search/api/v1/retrieve.py +130 -0
- nucliadb/search/api/v1/search.py +28 -32
- nucliadb/search/api/v1/suggest.py +11 -14
- nucliadb/search/api/v1/summarize.py +1 -2
- nucliadb/search/api/v1/utils.py +2 -2
- nucliadb/search/app.py +3 -2
- nucliadb/search/augmentor/__init__.py +21 -0
- nucliadb/search/augmentor/augmentor.py +232 -0
- nucliadb/search/augmentor/fields.py +704 -0
- nucliadb/search/augmentor/metrics.py +24 -0
- nucliadb/search/augmentor/paragraphs.py +334 -0
- nucliadb/search/augmentor/resources.py +238 -0
- nucliadb/search/augmentor/utils.py +33 -0
- nucliadb/search/lifecycle.py +3 -1
- nucliadb/search/predict.py +24 -17
- nucliadb/search/predict_models.py +8 -9
- nucliadb/search/requesters/utils.py +11 -10
- nucliadb/search/search/cache.py +19 -23
- nucliadb/search/search/chat/ask.py +88 -59
- nucliadb/search/search/chat/exceptions.py +3 -5
- nucliadb/search/search/chat/fetcher.py +201 -0
- nucliadb/search/search/chat/images.py +6 -4
- nucliadb/search/search/chat/old_prompt.py +1375 -0
- nucliadb/search/search/chat/parser.py +510 -0
- nucliadb/search/search/chat/prompt.py +563 -615
- nucliadb/search/search/chat/query.py +449 -36
- nucliadb/search/search/chat/rpc.py +85 -0
- nucliadb/search/search/fetch.py +3 -4
- nucliadb/search/search/filters.py +8 -11
- nucliadb/search/search/find.py +33 -31
- nucliadb/search/search/find_merge.py +124 -331
- nucliadb/search/search/graph_strategy.py +14 -12
- nucliadb/search/search/hydrator/__init__.py +3 -152
- nucliadb/search/search/hydrator/fields.py +92 -50
- nucliadb/search/search/hydrator/images.py +7 -7
- nucliadb/search/search/hydrator/paragraphs.py +42 -26
- nucliadb/search/search/hydrator/resources.py +20 -16
- nucliadb/search/search/ingestion_agents.py +5 -5
- nucliadb/search/search/merge.py +90 -94
- nucliadb/search/search/metrics.py +10 -9
- nucliadb/search/search/paragraphs.py +7 -9
- nucliadb/search/search/predict_proxy.py +13 -9
- nucliadb/search/search/query.py +14 -86
- nucliadb/search/search/query_parser/fetcher.py +51 -82
- nucliadb/search/search/query_parser/models.py +19 -20
- nucliadb/search/search/query_parser/old_filters.py +20 -19
- nucliadb/search/search/query_parser/parsers/ask.py +4 -5
- nucliadb/search/search/query_parser/parsers/catalog.py +5 -6
- nucliadb/search/search/query_parser/parsers/common.py +5 -6
- nucliadb/search/search/query_parser/parsers/find.py +6 -26
- nucliadb/search/search/query_parser/parsers/graph.py +13 -23
- nucliadb/search/search/query_parser/parsers/retrieve.py +207 -0
- nucliadb/search/search/query_parser/parsers/search.py +15 -53
- nucliadb/search/search/query_parser/parsers/unit_retrieval.py +8 -29
- nucliadb/search/search/rank_fusion.py +18 -13
- nucliadb/search/search/rerankers.py +5 -6
- nucliadb/search/search/retrieval.py +300 -0
- nucliadb/search/search/summarize.py +5 -6
- nucliadb/search/search/utils.py +3 -4
- nucliadb/search/settings.py +1 -2
- nucliadb/standalone/api_router.py +1 -1
- nucliadb/standalone/app.py +4 -3
- nucliadb/standalone/auth.py +5 -6
- nucliadb/standalone/lifecycle.py +2 -2
- nucliadb/standalone/run.py +2 -4
- nucliadb/standalone/settings.py +5 -6
- nucliadb/standalone/versions.py +3 -4
- nucliadb/tasks/consumer.py +13 -8
- nucliadb/tasks/models.py +2 -1
- nucliadb/tasks/producer.py +3 -3
- nucliadb/tasks/retries.py +8 -7
- nucliadb/train/api/utils.py +1 -3
- nucliadb/train/api/v1/shards.py +1 -2
- nucliadb/train/api/v1/trainset.py +1 -2
- nucliadb/train/app.py +1 -1
- nucliadb/train/generator.py +4 -4
- nucliadb/train/generators/field_classifier.py +2 -2
- nucliadb/train/generators/field_streaming.py +6 -6
- nucliadb/train/generators/image_classifier.py +2 -2
- nucliadb/train/generators/paragraph_classifier.py +2 -2
- nucliadb/train/generators/paragraph_streaming.py +2 -2
- nucliadb/train/generators/question_answer_streaming.py +2 -2
- nucliadb/train/generators/sentence_classifier.py +2 -2
- nucliadb/train/generators/token_classifier.py +3 -2
- nucliadb/train/generators/utils.py +6 -5
- nucliadb/train/nodes.py +3 -3
- nucliadb/train/resource.py +6 -8
- nucliadb/train/settings.py +3 -4
- nucliadb/train/types.py +11 -11
- nucliadb/train/upload.py +3 -2
- nucliadb/train/uploader.py +1 -2
- nucliadb/train/utils.py +1 -2
- nucliadb/writer/api/v1/export_import.py +4 -1
- nucliadb/writer/api/v1/field.py +7 -11
- nucliadb/writer/api/v1/knowledgebox.py +3 -4
- nucliadb/writer/api/v1/resource.py +9 -20
- nucliadb/writer/api/v1/services.py +10 -132
- nucliadb/writer/api/v1/upload.py +73 -72
- nucliadb/writer/app.py +8 -2
- nucliadb/writer/resource/basic.py +12 -15
- nucliadb/writer/resource/field.py +7 -5
- nucliadb/writer/resource/origin.py +7 -0
- nucliadb/writer/settings.py +2 -3
- nucliadb/writer/tus/__init__.py +2 -3
- nucliadb/writer/tus/azure.py +1 -3
- nucliadb/writer/tus/dm.py +3 -3
- nucliadb/writer/tus/exceptions.py +3 -4
- nucliadb/writer/tus/gcs.py +5 -6
- nucliadb/writer/tus/s3.py +2 -3
- nucliadb/writer/tus/storage.py +3 -3
- {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/METADATA +9 -10
- nucliadb-6.10.0.post5705.dist-info/RECORD +410 -0
- nucliadb/common/datamanagers/entities.py +0 -139
- nucliadb-6.9.1.post5192.dist-info/RECORD +0 -392
- {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/WHEEL +0 -0
- {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/top_level.txt +0 -0
nucliadb/search/search/chat/old_prompt.py
@@ -0,0 +1,1375 @@
|
|
|
1
|
+
# Copyright (C) 2021 Bosutech XXI S.L.
|
|
2
|
+
#
|
|
3
|
+
# nucliadb is offered under the AGPL v3.0 and as commercial software.
|
|
4
|
+
# For commercial licensing, contact us at info@nuclia.com.
|
|
5
|
+
#
|
|
6
|
+
# AGPL:
|
|
7
|
+
# This program is free software: you can redistribute it and/or modify
|
|
8
|
+
# it under the terms of the GNU Affero General Public License as
|
|
9
|
+
# published by the Free Software Foundation, either version 3 of the
|
|
10
|
+
# License, or (at your option) any later version.
|
|
11
|
+
#
|
|
12
|
+
# This program is distributed in the hope that it will be useful,
|
|
13
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
14
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
15
|
+
# GNU Affero General Public License for more details.
|
|
16
|
+
#
|
|
17
|
+
# You should have received a copy of the GNU Affero General Public License
|
|
18
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
19
|
+
#
|
|
20
|
+
import asyncio
|
|
21
|
+
import copy
|
|
22
|
+
from collections import deque
|
|
23
|
+
from collections.abc import Sequence
|
|
24
|
+
from dataclasses import dataclass
|
|
25
|
+
from typing import Deque, cast
|
|
26
|
+
|
|
27
|
+
import yaml
|
|
28
|
+
from pydantic import BaseModel
|
|
29
|
+
|
|
30
|
+
from nucliadb.common import datamanagers
|
|
31
|
+
from nucliadb.common.ids import FIELD_TYPE_PB_TO_STR, FIELD_TYPE_STR_TO_PB, FieldId, ParagraphId
|
|
32
|
+
from nucliadb.common.maindb.utils import get_driver
|
|
33
|
+
from nucliadb.common.models_utils import from_proto
|
|
34
|
+
from nucliadb.ingest.fields.base import Field
|
|
35
|
+
from nucliadb.ingest.fields.conversation import Conversation
|
|
36
|
+
from nucliadb.ingest.fields.file import File
|
|
37
|
+
from nucliadb.ingest.orm.knowledgebox import KnowledgeBox as KnowledgeBoxORM
|
|
38
|
+
from nucliadb.search import logger
|
|
39
|
+
from nucliadb.search.search import cache
|
|
40
|
+
from nucliadb.search.search.chat.images import (
|
|
41
|
+
get_file_thumbnail_image,
|
|
42
|
+
get_page_image,
|
|
43
|
+
get_paragraph_image,
|
|
44
|
+
)
|
|
45
|
+
from nucliadb.search.search.metrics import Metrics
|
|
46
|
+
from nucliadb.search.search.paragraphs import get_paragraph_text
|
|
47
|
+
from nucliadb_models.labels import translate_alias_to_system_label
|
|
48
|
+
from nucliadb_models.metadata import Extra, Origin
|
|
49
|
+
from nucliadb_models.search import (
|
|
50
|
+
SCORE_TYPE,
|
|
51
|
+
AugmentedContext,
|
|
52
|
+
AugmentedTextBlock,
|
|
53
|
+
ConversationalStrategy,
|
|
54
|
+
FieldExtensionStrategy,
|
|
55
|
+
FindParagraph,
|
|
56
|
+
FullResourceStrategy,
|
|
57
|
+
HierarchyResourceStrategy,
|
|
58
|
+
Image,
|
|
59
|
+
ImageRagStrategy,
|
|
60
|
+
ImageRagStrategyName,
|
|
61
|
+
MetadataExtensionStrategy,
|
|
62
|
+
MetadataExtensionType,
|
|
63
|
+
NeighbouringParagraphsStrategy,
|
|
64
|
+
PageImageStrategy,
|
|
65
|
+
ParagraphImageStrategy,
|
|
66
|
+
PromptContext,
|
|
67
|
+
PromptContextImages,
|
|
68
|
+
PromptContextOrder,
|
|
69
|
+
RagStrategy,
|
|
70
|
+
RagStrategyName,
|
|
71
|
+
TableImageStrategy,
|
|
72
|
+
TextBlockAugmentationType,
|
|
73
|
+
TextPosition,
|
|
74
|
+
)
|
|
75
|
+
from nucliadb_protos import resources_pb2
|
|
76
|
+
from nucliadb_protos.resources_pb2 import ExtractedText, FieldComputedMetadata
|
|
77
|
+
from nucliadb_telemetry.metrics import Observer
|
|
78
|
+
from nucliadb_utils.asyncio_utils import ConcurrentRunner, run_concurrently
|
|
79
|
+
from nucliadb_utils.utilities import get_storage
|
|
80
|
+
|
|
81
|
+
MAX_RESOURCE_TASKS = 5
|
|
82
|
+
MAX_RESOURCE_FIELD_TASKS = 4
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
# Number of messages to pull after a match in a message
|
|
86
|
+
# The hope here is it will be enough to get the answer to the question.
|
|
87
|
+
CONVERSATION_MESSAGE_CONTEXT_EXPANSION = 15
|
|
88
|
+
|
|
89
|
+
TextBlockId = ParagraphId | FieldId
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class ParagraphIdNotFoundInExtractedMetadata(Exception):
|
|
93
|
+
pass
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class CappedPromptContext:
|
|
97
|
+
"""
|
|
98
|
+
Class to keep track of the size (in number of characters) of the prompt context
|
|
99
|
+
and automatically trim data that exceeds the limit when it's being set on the dictionary.
|
|
100
|
+
"""
|
|
101
|
+
|
|
102
|
+
def __init__(self, max_size: int | None):
|
|
103
|
+
self.output: PromptContext = {}
|
|
104
|
+
self.images: PromptContextImages = {}
|
|
105
|
+
self.max_size = max_size
|
|
106
|
+
|
|
107
|
+
def __setitem__(self, key: str, value: str) -> None:
|
|
108
|
+
self.output.__setitem__(key, value)
|
|
109
|
+
|
|
110
|
+
def __getitem__(self, key: str) -> str:
|
|
111
|
+
return self.output.__getitem__(key)
|
|
112
|
+
|
|
113
|
+
def __contains__(self, key: str) -> bool:
|
|
114
|
+
return key in self.output
|
|
115
|
+
|
|
116
|
+
def __delitem__(self, key: str) -> None:
|
|
117
|
+
try:
|
|
118
|
+
self.output.__delitem__(key)
|
|
119
|
+
except KeyError:
|
|
120
|
+
pass
|
|
121
|
+
|
|
122
|
+
def text_block_ids(self) -> list[str]:
|
|
123
|
+
return list(self.output.keys())
|
|
124
|
+
|
|
125
|
+
@property
|
|
126
|
+
def size(self) -> int:
|
|
127
|
+
"""
|
|
128
|
+
Returns the total size of the context in characters.
|
|
129
|
+
"""
|
|
130
|
+
return sum(len(text) for text in self.output.values())
|
|
131
|
+
|
|
132
|
+
def cap(self) -> dict[str, str]:
|
|
133
|
+
"""
|
|
134
|
+
This method will trim the context to the maximum size if it exceeds it.
|
|
135
|
+
It will remove text from the most recent entries first, until the size is below the limit.
|
|
136
|
+
"""
|
|
137
|
+
if self.max_size is None:
|
|
138
|
+
return self.output
|
|
139
|
+
|
|
140
|
+
if self.size <= self.max_size:
|
|
141
|
+
return self.output
|
|
142
|
+
|
|
143
|
+
logger.info("Removing text from context to fit within the max size limit")
|
|
144
|
+
# Iterate the dictionary in reverse order of insertion
|
|
145
|
+
for key in reversed(list(self.output.keys())):
|
|
146
|
+
current_size = self.size
|
|
147
|
+
if current_size <= self.max_size:
|
|
148
|
+
break
|
|
149
|
+
# Remove text from the value
|
|
150
|
+
text = self.output[key]
|
|
151
|
+
# If removing the whole text still keeps the total size above the limit, remove it
|
|
152
|
+
if current_size - len(text) >= self.max_size:
|
|
153
|
+
del self.output[key]
|
|
154
|
+
else:
|
|
155
|
+
# Otherwise, trim the text to fit within the limit
|
|
156
|
+
excess_size = current_size - self.max_size
|
|
157
|
+
if excess_size > 0:
|
|
158
|
+
trimmed_text = text[:-excess_size]
|
|
159
|
+
self.output[key] = trimmed_text
|
|
160
|
+
return self.output
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
async def get_next_conversation_messages(
|
|
164
|
+
*,
|
|
165
|
+
field_obj: Conversation,
|
|
166
|
+
page: int,
|
|
167
|
+
start_idx: int,
|
|
168
|
+
num_messages: int,
|
|
169
|
+
message_type: resources_pb2.Message.MessageType.ValueType | None = None,
|
|
170
|
+
msg_to: str | None = None,
|
|
171
|
+
) -> list[resources_pb2.Message]:
|
|
172
|
+
output = []
|
|
173
|
+
cmetadata = await field_obj.get_metadata()
|
|
174
|
+
for current_page in range(page, cmetadata.pages + 1):
|
|
175
|
+
conv = await field_obj.db_get_value(current_page)
|
|
176
|
+
for message in conv.messages[start_idx:]:
|
|
177
|
+
if message_type is not None and message.type != message_type: # pragma: no cover
|
|
178
|
+
continue
|
|
179
|
+
if msg_to is not None and msg_to not in message.to: # pragma: no cover
|
|
180
|
+
continue
|
|
181
|
+
output.append(message)
|
|
182
|
+
if len(output) >= num_messages:
|
|
183
|
+
return output
|
|
184
|
+
start_idx = 0
|
|
185
|
+
|
|
186
|
+
return output
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
async def find_conversation_message(
|
|
190
|
+
field_obj: Conversation, mident: str
|
|
191
|
+
) -> tuple[resources_pb2.Message | None, int, int]:
|
|
192
|
+
cmetadata = await field_obj.get_metadata()
|
|
193
|
+
for page in range(1, cmetadata.pages + 1):
|
|
194
|
+
conv = await field_obj.db_get_value(page)
|
|
195
|
+
for idx, message in enumerate(conv.messages):
|
|
196
|
+
if message.ident == mident:
|
|
197
|
+
return message, page, idx
|
|
198
|
+
return None, -1, -1
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
async def get_expanded_conversation_messages(
|
|
202
|
+
*,
|
|
203
|
+
kb: KnowledgeBoxORM,
|
|
204
|
+
rid: str,
|
|
205
|
+
field_id: str,
|
|
206
|
+
mident: str,
|
|
207
|
+
max_messages: int = CONVERSATION_MESSAGE_CONTEXT_EXPANSION,
|
|
208
|
+
) -> list[resources_pb2.Message]:
|
|
209
|
+
resource = await kb.get(rid)
|
|
210
|
+
if resource is None: # pragma: no cover
|
|
211
|
+
return []
|
|
212
|
+
field_obj: Conversation = await resource.get_field(field_id, FIELD_TYPE_STR_TO_PB["c"], load=True) # type: ignore
|
|
213
|
+
found_message, found_page, found_idx = await find_conversation_message(
|
|
214
|
+
field_obj=field_obj, mident=mident
|
|
215
|
+
)
|
|
216
|
+
if found_message is None: # pragma: no cover
|
|
217
|
+
return []
|
|
218
|
+
elif found_message.type == resources_pb2.Message.MessageType.QUESTION:
|
|
219
|
+
# only try to get answer if it was a question
|
|
220
|
+
return await get_next_conversation_messages(
|
|
221
|
+
field_obj=field_obj,
|
|
222
|
+
page=found_page,
|
|
223
|
+
start_idx=found_idx + 1,
|
|
224
|
+
num_messages=1,
|
|
225
|
+
message_type=resources_pb2.Message.MessageType.ANSWER,
|
|
226
|
+
)
|
|
227
|
+
else:
|
|
228
|
+
return await get_next_conversation_messages(
|
|
229
|
+
field_obj=field_obj,
|
|
230
|
+
page=found_page,
|
|
231
|
+
start_idx=found_idx + 1,
|
|
232
|
+
num_messages=max_messages,
|
|
233
|
+
)
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
async def default_prompt_context(
|
|
237
|
+
context: CappedPromptContext,
|
|
238
|
+
kbid: str,
|
|
239
|
+
ordered_paragraphs: list[FindParagraph],
|
|
240
|
+
) -> None:
|
|
241
|
+
"""
|
|
242
|
+
- Updates context (which is an ordered dict of text_block_id -> context_text).
|
|
243
|
+
- text_block_id is typically the paragraph id, but has a special value for the
|
|
244
|
+
user context. (USER_CONTEXT_0, USER_CONTEXT_1, ...)
|
|
245
|
+
- Paragraphs are inserted in order of relevance, by increasing `order` field
|
|
246
|
+
of the find result paragraphs.
|
|
247
|
+
- User context is inserted first, in order of appearance.
|
|
248
|
+
- Using an dict prevents from duplicates pulled in through conversation expansion.
|
|
249
|
+
"""
|
|
250
|
+
# Sort retrieved paragraphs by decreasing order (most relevant first)
|
|
251
|
+
async with get_driver().ro_transaction() as txn:
|
|
252
|
+
storage = await get_storage()
|
|
253
|
+
kb = KnowledgeBoxORM(txn, storage, kbid)
|
|
254
|
+
for paragraph in ordered_paragraphs:
|
|
255
|
+
context[paragraph.id] = _clean_paragraph_text(paragraph)
|
|
256
|
+
|
|
257
|
+
# If the paragraph is a conversation and it matches semantically, we assume we
|
|
258
|
+
# have matched with the question, therefore try to include the answer to the
|
|
259
|
+
# context by pulling the next few messages of the conversation field
|
|
260
|
+
rid, field_type, field_id, mident = paragraph.id.split("/")[:4]
|
|
261
|
+
if field_type == "c" and paragraph.score_type in (
|
|
262
|
+
SCORE_TYPE.VECTOR,
|
|
263
|
+
SCORE_TYPE.BOTH,
|
|
264
|
+
):
|
|
265
|
+
expanded_msgs = await get_expanded_conversation_messages(
|
|
266
|
+
kb=kb, rid=rid, field_id=field_id, mident=mident
|
|
267
|
+
)
|
|
268
|
+
for msg in expanded_msgs:
|
|
269
|
+
text = msg.content.text.strip()
|
|
270
|
+
pid = f"{rid}/{field_type}/{field_id}/{msg.ident}/0-{len(msg.content.text)}"
|
|
271
|
+
context[pid] = text
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
async def full_resource_prompt_context(
|
|
275
|
+
context: CappedPromptContext,
|
|
276
|
+
kbid: str,
|
|
277
|
+
ordered_paragraphs: list[FindParagraph],
|
|
278
|
+
resource: str | None,
|
|
279
|
+
strategy: FullResourceStrategy,
|
|
280
|
+
metrics: Metrics,
|
|
281
|
+
augmented_context: AugmentedContext,
|
|
282
|
+
) -> None:
|
|
283
|
+
"""
|
|
284
|
+
Algorithm steps:
|
|
285
|
+
- Collect the list of resources in the results (in order of relevance).
|
|
286
|
+
- For each resource, collect the extracted text from all its fields and craft the context.
|
|
287
|
+
Arguments:
|
|
288
|
+
context: The context to be updated.
|
|
289
|
+
kbid: The knowledge box id.
|
|
290
|
+
ordered_paragraphs: The results of the retrieval (find) operation.
|
|
291
|
+
resource: The resource to be included in the context. This is used only when chatting with a specific resource with no retrieval.
|
|
292
|
+
strategy: strategy instance containing, for example, the number of full resources to include in the context.
|
|
293
|
+
"""
|
|
294
|
+
if resource is not None:
|
|
295
|
+
# The user has specified a resource to be included in the context.
|
|
296
|
+
ordered_resources = [resource]
|
|
297
|
+
else:
|
|
298
|
+
# Collect the list of resources in the results (in order of relevance).
|
|
299
|
+
ordered_resources = []
|
|
300
|
+
for paragraph in ordered_paragraphs:
|
|
301
|
+
resource_uuid = parse_text_block_id(paragraph.id).rid
|
|
302
|
+
if resource_uuid not in ordered_resources:
|
|
303
|
+
skip = False
|
|
304
|
+
if strategy.apply_to is not None:
|
|
305
|
+
# decide whether the resource should be extended or not
|
|
306
|
+
for label in strategy.apply_to.exclude:
|
|
307
|
+
skip = skip or (
|
|
308
|
+
translate_alias_to_system_label(label) in (paragraph.labels or [])
|
|
309
|
+
)
|
|
310
|
+
|
|
311
|
+
if not skip:
|
|
312
|
+
ordered_resources.append(resource_uuid)
|
|
313
|
+
|
|
314
|
+
# For each resource, collect the extracted text from all its fields.
|
|
315
|
+
resources_extracted_texts = await run_concurrently(
|
|
316
|
+
[
|
|
317
|
+
hydrate_resource_text(kbid, resource_uuid, max_concurrent_tasks=MAX_RESOURCE_FIELD_TASKS)
|
|
318
|
+
for resource_uuid in ordered_resources[: strategy.count]
|
|
319
|
+
],
|
|
320
|
+
max_concurrent=MAX_RESOURCE_TASKS,
|
|
321
|
+
)
|
|
322
|
+
added_fields = set()
|
|
323
|
+
for resource_extracted_texts in resources_extracted_texts:
|
|
324
|
+
if resource_extracted_texts is None:
|
|
325
|
+
continue
|
|
326
|
+
for field, extracted_text in resource_extracted_texts:
|
|
327
|
+
# First off, remove the text block ids from paragraphs that belong to
|
|
328
|
+
# the same field, as otherwise the context will be duplicated.
|
|
329
|
+
for tb_id in context.text_block_ids():
|
|
330
|
+
if tb_id.startswith(field.full()):
|
|
331
|
+
del context[tb_id]
|
|
332
|
+
# Add the extracted text of each field to the context.
|
|
333
|
+
context[field.full()] = extracted_text
|
|
334
|
+
augmented_context.fields[field.full()] = AugmentedTextBlock(
|
|
335
|
+
id=field.full(),
|
|
336
|
+
text=extracted_text,
|
|
337
|
+
augmentation_type=TextBlockAugmentationType.FULL_RESOURCE,
|
|
338
|
+
)
|
|
339
|
+
|
|
340
|
+
added_fields.add(field.full())
|
|
341
|
+
|
|
342
|
+
metrics.set("full_resource_ops", len(added_fields))
|
|
343
|
+
|
|
344
|
+
if strategy.include_remaining_text_blocks:
|
|
345
|
+
for paragraph in ordered_paragraphs:
|
|
346
|
+
pid = cast(ParagraphId, parse_text_block_id(paragraph.id))
|
|
347
|
+
if pid.field_id.full() not in added_fields:
|
|
348
|
+
context[paragraph.id] = _clean_paragraph_text(paragraph)
|
|
349
|
+
|
|
350
|
+
|
|
351
|
+
async def extend_prompt_context_with_metadata(
|
|
352
|
+
context: CappedPromptContext,
|
|
353
|
+
kbid: str,
|
|
354
|
+
strategy: MetadataExtensionStrategy,
|
|
355
|
+
metrics: Metrics,
|
|
356
|
+
augmented_context: AugmentedContext,
|
|
357
|
+
) -> None:
|
|
358
|
+
text_block_ids: list[TextBlockId] = []
|
|
359
|
+
for text_block_id in context.text_block_ids():
|
|
360
|
+
try:
|
|
361
|
+
text_block_ids.append(parse_text_block_id(text_block_id))
|
|
362
|
+
except ValueError: # pragma: no cover
|
|
363
|
+
# Some text block ids are not paragraphs nor fields, so they are skipped
|
|
364
|
+
# (e.g. USER_CONTEXT_0, when the user provides extra context)
|
|
365
|
+
continue
|
|
366
|
+
if len(text_block_ids) == 0: # pragma: no cover
|
|
367
|
+
return
|
|
368
|
+
|
|
369
|
+
ops = 0
|
|
370
|
+
if MetadataExtensionType.ORIGIN in strategy.types:
|
|
371
|
+
ops += 1
|
|
372
|
+
await extend_prompt_context_with_origin_metadata(
|
|
373
|
+
context, kbid, text_block_ids, augmented_context
|
|
374
|
+
)
|
|
375
|
+
|
|
376
|
+
if MetadataExtensionType.CLASSIFICATION_LABELS in strategy.types:
|
|
377
|
+
ops += 1
|
|
378
|
+
await extend_prompt_context_with_classification_labels(
|
|
379
|
+
context, kbid, text_block_ids, augmented_context
|
|
380
|
+
)
|
|
381
|
+
|
|
382
|
+
if MetadataExtensionType.NERS in strategy.types:
|
|
383
|
+
ops += 1
|
|
384
|
+
await extend_prompt_context_with_ner(context, kbid, text_block_ids, augmented_context)
|
|
385
|
+
|
|
386
|
+
if MetadataExtensionType.EXTRA_METADATA in strategy.types:
|
|
387
|
+
ops += 1
|
|
388
|
+
await extend_prompt_context_with_extra_metadata(context, kbid, text_block_ids, augmented_context)
|
|
389
|
+
|
|
390
|
+
metrics.set("metadata_extension_ops", ops * len(text_block_ids))
|
|
391
|
+
|
|
392
|
+
|
|
393
|
+
def parse_text_block_id(text_block_id: str) -> TextBlockId:
|
|
394
|
+
try:
|
|
395
|
+
# Typically, the text block id is a paragraph id
|
|
396
|
+
return ParagraphId.from_string(text_block_id)
|
|
397
|
+
except ValueError:
|
|
398
|
+
# When we're doing `full_resource` or `hierarchy` strategies,the text block id
|
|
399
|
+
# is a field id
|
|
400
|
+
return FieldId.from_string(text_block_id)
|
|
401
|
+
|
|
402
|
+
|
|
403
|
+
async def extend_prompt_context_with_origin_metadata(
|
|
404
|
+
context: CappedPromptContext,
|
|
405
|
+
kbid,
|
|
406
|
+
text_block_ids: list[TextBlockId],
|
|
407
|
+
augmented_context: AugmentedContext,
|
|
408
|
+
):
|
|
409
|
+
async def _get_origin(kbid: str, rid: str) -> tuple[str, Origin | None]:
|
|
410
|
+
origin = None
|
|
411
|
+
resource = await cache.get_resource(kbid, rid)
|
|
412
|
+
if resource is not None:
|
|
413
|
+
pb_origin = await resource.get_origin()
|
|
414
|
+
if pb_origin is not None:
|
|
415
|
+
origin = from_proto.origin(pb_origin)
|
|
416
|
+
return rid, origin
|
|
417
|
+
|
|
418
|
+
rids = {tb_id.rid for tb_id in text_block_ids}
|
|
419
|
+
origins = await run_concurrently([_get_origin(kbid, rid) for rid in rids])
|
|
420
|
+
rid_to_origin = {rid: origin for rid, origin in origins if origin is not None}
|
|
421
|
+
for tb_id in text_block_ids:
|
|
422
|
+
origin = rid_to_origin.get(tb_id.rid)
|
|
423
|
+
if origin is not None and tb_id.full() in context:
|
|
424
|
+
text = context.output.pop(tb_id.full())
|
|
425
|
+
extended_text = text + f"\n\nDOCUMENT METADATA AT ORIGIN:\n{to_yaml(origin)}"
|
|
426
|
+
context[tb_id.full()] = extended_text
|
|
427
|
+
augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
|
|
428
|
+
id=tb_id.full(),
|
|
429
|
+
text=extended_text,
|
|
430
|
+
parent=tb_id.full(),
|
|
431
|
+
augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
|
|
432
|
+
)
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
async def extend_prompt_context_with_classification_labels(
|
|
436
|
+
context: CappedPromptContext,
|
|
437
|
+
kbid: str,
|
|
438
|
+
text_block_ids: list[TextBlockId],
|
|
439
|
+
augmented_context: AugmentedContext,
|
|
440
|
+
):
|
|
441
|
+
async def _get_labels(kbid: str, _id: TextBlockId) -> tuple[TextBlockId, list[tuple[str, str]]]:
|
|
442
|
+
fid = _id if isinstance(_id, FieldId) else _id.field_id
|
|
443
|
+
labels = set()
|
|
444
|
+
resource = await cache.get_resource(kbid, fid.rid)
|
|
445
|
+
if resource is not None:
|
|
446
|
+
pb_basic = await resource.get_basic()
|
|
447
|
+
if pb_basic is not None:
|
|
448
|
+
# Add the classification labels of the resource
|
|
449
|
+
for classif in pb_basic.usermetadata.classifications:
|
|
450
|
+
labels.add((classif.labelset, classif.label))
|
|
451
|
+
# Add the classifications labels of the field
|
|
452
|
+
for fc in pb_basic.computedmetadata.field_classifications:
|
|
453
|
+
if fc.field.field == fid.key and fc.field.field_type == fid.pb_type:
|
|
454
|
+
for classif in fc.classifications:
|
|
455
|
+
if classif.cancelled_by_user: # pragma: no cover
|
|
456
|
+
continue
|
|
457
|
+
labels.add((classif.labelset, classif.label))
|
|
458
|
+
return _id, list(labels)
|
|
459
|
+
|
|
460
|
+
classif_labels = await run_concurrently([_get_labels(kbid, tb_id) for tb_id in text_block_ids])
|
|
461
|
+
tb_id_to_labels = {tb_id: labels for tb_id, labels in classif_labels if len(labels) > 0}
|
|
462
|
+
for tb_id in text_block_ids:
|
|
463
|
+
labels = tb_id_to_labels.get(tb_id)
|
|
464
|
+
if labels is not None and tb_id.full() in context:
|
|
465
|
+
text = context.output.pop(tb_id.full())
|
|
466
|
+
|
|
467
|
+
labels_text = "DOCUMENT CLASSIFICATION LABELS:"
|
|
468
|
+
for labelset, label in labels:
|
|
469
|
+
labels_text += f"\n - {label} ({labelset})"
|
|
470
|
+
extended_text = text + "\n\n" + labels_text
|
|
471
|
+
|
|
472
|
+
context[tb_id.full()] = extended_text
|
|
473
|
+
augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
|
|
474
|
+
id=tb_id.full(),
|
|
475
|
+
text=extended_text,
|
|
476
|
+
parent=tb_id.full(),
|
|
477
|
+
augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
|
|
478
|
+
)
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
async def extend_prompt_context_with_ner(
|
|
482
|
+
context: CappedPromptContext,
|
|
483
|
+
kbid: str,
|
|
484
|
+
text_block_ids: list[TextBlockId],
|
|
485
|
+
augmented_context: AugmentedContext,
|
|
486
|
+
):
|
|
487
|
+
async def _get_ners(kbid: str, _id: TextBlockId) -> tuple[TextBlockId, dict[str, set[str]]]:
|
|
488
|
+
fid = _id if isinstance(_id, FieldId) else _id.field_id
|
|
489
|
+
ners: dict[str, set[str]] = {}
|
|
490
|
+
resource = await cache.get_resource(kbid, fid.rid)
|
|
491
|
+
if resource is not None:
|
|
492
|
+
field = await resource.get_field(fid.key, fid.pb_type, load=False)
|
|
493
|
+
fcm = await field.get_field_metadata()
|
|
494
|
+
if fcm is not None:
|
|
495
|
+
# Data Augmentation + Processor entities
|
|
496
|
+
for (
|
|
497
|
+
data_aumgentation_task_id,
|
|
498
|
+
entities_wrapper,
|
|
499
|
+
) in fcm.metadata.entities.items():
|
|
500
|
+
for entity in entities_wrapper.entities:
|
|
501
|
+
ners.setdefault(entity.label, set()).add(entity.text)
|
|
502
|
+
# Legacy processor entities
|
|
503
|
+
# TODO: Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
|
|
504
|
+
for token, family in fcm.metadata.ner.items():
|
|
505
|
+
ners.setdefault(family, set()).add(token)
|
|
506
|
+
return _id, ners
|
|
507
|
+
|
|
508
|
+
nerss = await run_concurrently([_get_ners(kbid, tb_id) for tb_id in text_block_ids])
|
|
509
|
+
tb_id_to_ners = {tb_id: ners for tb_id, ners in nerss if len(ners) > 0}
|
|
510
|
+
for tb_id in text_block_ids:
|
|
511
|
+
ners = tb_id_to_ners.get(tb_id)
|
|
512
|
+
if ners is not None and tb_id.full() in context:
|
|
513
|
+
text = context.output.pop(tb_id.full())
|
|
514
|
+
|
|
515
|
+
ners_text = "DOCUMENT NAMED ENTITIES (NERs):"
|
|
516
|
+
for family, tokens in ners.items():
|
|
517
|
+
ners_text += f"\n - {family}:"
|
|
518
|
+
for token in sorted(list(tokens)):
|
|
519
|
+
ners_text += f"\n - {token}"
|
|
520
|
+
|
|
521
|
+
extended_text = text + "\n\n" + ners_text
|
|
522
|
+
|
|
523
|
+
context[tb_id.full()] = extended_text
|
|
524
|
+
augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
|
|
525
|
+
id=tb_id.full(),
|
|
526
|
+
text=extended_text,
|
|
527
|
+
parent=tb_id.full(),
|
|
528
|
+
augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
|
|
529
|
+
)
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
async def extend_prompt_context_with_extra_metadata(
|
|
533
|
+
context: CappedPromptContext,
|
|
534
|
+
kbid: str,
|
|
535
|
+
text_block_ids: list[TextBlockId],
|
|
536
|
+
augmented_context: AugmentedContext,
|
|
537
|
+
):
|
|
538
|
+
async def _get_extra(kbid: str, rid: str) -> tuple[str, Extra | None]:
|
|
539
|
+
extra = None
|
|
540
|
+
resource = await cache.get_resource(kbid, rid)
|
|
541
|
+
if resource is not None:
|
|
542
|
+
pb_extra = await resource.get_extra()
|
|
543
|
+
if pb_extra is not None:
|
|
544
|
+
extra = from_proto.extra(pb_extra)
|
|
545
|
+
return rid, extra
|
|
546
|
+
|
|
547
|
+
rids = {tb_id.rid for tb_id in text_block_ids}
|
|
548
|
+
extras = await run_concurrently([_get_extra(kbid, rid) for rid in rids])
|
|
549
|
+
rid_to_extra = {rid: extra for rid, extra in extras if extra is not None}
|
|
550
|
+
for tb_id in text_block_ids:
|
|
551
|
+
extra = rid_to_extra.get(tb_id.rid)
|
|
552
|
+
if extra is not None and tb_id.full() in context:
|
|
553
|
+
text = context.output.pop(tb_id.full())
|
|
554
|
+
extended_text = text + f"\n\nDOCUMENT EXTRA METADATA:\n{to_yaml(extra)}"
|
|
555
|
+
context[tb_id.full()] = extended_text
|
|
556
|
+
augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
|
|
557
|
+
id=tb_id.full(),
|
|
558
|
+
text=extended_text,
|
|
559
|
+
parent=tb_id.full(),
|
|
560
|
+
augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
|
|
561
|
+
)
|
|
562
|
+
|
|
563
|
+
|
|
564
|
+
def to_yaml(obj: BaseModel) -> str:
|
|
565
|
+
return yaml.dump(
|
|
566
|
+
obj.model_dump(exclude_none=True, exclude_defaults=True, exclude_unset=True),
|
|
567
|
+
default_flow_style=False,
|
|
568
|
+
indent=2,
|
|
569
|
+
sort_keys=True,
|
|
570
|
+
)
|
|
571
|
+
|
|
572
|
+
|
|
573
|
+
async def field_extension_prompt_context(
|
|
574
|
+
context: CappedPromptContext,
|
|
575
|
+
kbid: str,
|
|
576
|
+
ordered_paragraphs: list[FindParagraph],
|
|
577
|
+
strategy: FieldExtensionStrategy,
|
|
578
|
+
metrics: Metrics,
|
|
579
|
+
augmented_context: AugmentedContext,
|
|
580
|
+
) -> None:
|
|
581
|
+
"""
|
|
582
|
+
Algorithm steps:
|
|
583
|
+
- Collect the list of resources in the results (in order of relevance).
|
|
584
|
+
- For each resource, collect the extracted text from all its fields.
|
|
585
|
+
- Add the extracted text of each field to the beginning of the context.
|
|
586
|
+
- Add the extracted text of each paragraph to the end of the context.
|
|
587
|
+
"""
|
|
588
|
+
ordered_resources = []
|
|
589
|
+
for paragraph in ordered_paragraphs:
|
|
590
|
+
resource_uuid = ParagraphId.from_string(paragraph.id).rid
|
|
591
|
+
if resource_uuid not in ordered_resources:
|
|
592
|
+
ordered_resources.append(resource_uuid)
|
|
593
|
+
|
|
594
|
+
extend_field_ids = await get_matching_field_ids(kbid, ordered_resources, strategy)
|
|
595
|
+
tasks = [hydrate_field_text(kbid, fid) for fid in extend_field_ids]
|
|
596
|
+
field_extracted_texts = await run_concurrently(tasks)
|
|
597
|
+
|
|
598
|
+
metrics.set("field_extension_ops", len(field_extracted_texts))
|
|
599
|
+
|
|
600
|
+
for result in field_extracted_texts:
|
|
601
|
+
if result is None: # pragma: no cover
|
|
602
|
+
continue
|
|
603
|
+
field, extracted_text = result
|
|
604
|
+
# First off, remove the text block ids from paragraphs that belong to
|
|
605
|
+
# the same field, as otherwise the context will be duplicated.
|
|
606
|
+
for tb_id in context.text_block_ids():
|
|
607
|
+
if tb_id.startswith(field.full()):
|
|
608
|
+
del context[tb_id]
|
|
609
|
+
# Add the extracted text of each field to the beginning of the context.
|
|
610
|
+
if field.full() not in context:
|
|
611
|
+
context[field.full()] = extracted_text
|
|
612
|
+
augmented_context.fields[field.full()] = AugmentedTextBlock(
|
|
613
|
+
id=field.full(),
|
|
614
|
+
text=extracted_text,
|
|
615
|
+
augmentation_type=TextBlockAugmentationType.FIELD_EXTENSION,
|
|
616
|
+
)
|
|
617
|
+
|
|
618
|
+
# Add the extracted text of each paragraph to the end of the context.
|
|
619
|
+
for paragraph in ordered_paragraphs:
|
|
620
|
+
if paragraph.id not in context:
|
|
621
|
+
context[paragraph.id] = _clean_paragraph_text(paragraph)
|
|
622
|
+
|
|
623
|
+
|
|
624
|
+
async def get_matching_field_ids(
|
|
625
|
+
kbid: str, ordered_resources: list[str], strategy: FieldExtensionStrategy
|
|
626
|
+
) -> list[FieldId]:
|
|
627
|
+
extend_field_ids: list[FieldId] = []
|
|
628
|
+
# Fetch the extracted texts of the specified fields for each resource
|
|
629
|
+
for resource_uuid in ordered_resources:
|
|
630
|
+
for field_id in strategy.fields:
|
|
631
|
+
try:
|
|
632
|
+
fid = FieldId.from_string(f"{resource_uuid}/{field_id.strip('/')}")
|
|
633
|
+
extend_field_ids.append(fid)
|
|
634
|
+
except ValueError: # pragma: no cover
|
|
635
|
+
# Invalid field id, skiping
|
|
636
|
+
continue
|
|
637
|
+
if len(strategy.data_augmentation_field_prefixes) > 0:
|
|
638
|
+
for resource_uuid in ordered_resources:
|
|
639
|
+
all_field_ids = await datamanagers.atomic.resources.get_all_field_ids(
|
|
640
|
+
kbid=kbid, rid=resource_uuid, for_update=False
|
|
641
|
+
)
|
|
642
|
+
if all_field_ids is None:
|
|
643
|
+
continue
|
|
644
|
+
for fieldid in all_field_ids.fields:
|
|
645
|
+
# Generated fields are always text fields starting with "da-"
|
|
646
|
+
if any(
|
|
647
|
+
(
|
|
648
|
+
fieldid.field_type == resources_pb2.FieldType.TEXT
|
|
649
|
+
and fieldid.field.startswith(f"da-{prefix}-")
|
|
650
|
+
)
|
|
651
|
+
for prefix in strategy.data_augmentation_field_prefixes
|
|
652
|
+
):
|
|
653
|
+
extend_field_ids.append(
|
|
654
|
+
FieldId.from_pb(
|
|
655
|
+
rid=resource_uuid, field_type=fieldid.field_type, key=fieldid.field
|
|
656
|
+
)
|
|
657
|
+
)
|
|
658
|
+
return extend_field_ids
|
|
659
|
+
|
|
660
|
+
|
|
661
|
+
async def get_orm_field(kbid: str, field_id: FieldId) -> Field | None:
|
|
662
|
+
resource = await cache.get_resource(kbid, field_id.rid)
|
|
663
|
+
if resource is None: # pragma: no cover
|
|
664
|
+
return None
|
|
665
|
+
return await resource.get_field(key=field_id.key, type=field_id.pb_type, load=False)
|
|
666
|
+
|
|
667
|
+
|
|
668
|
+
async def neighbouring_paragraphs_prompt_context(
|
|
669
|
+
context: CappedPromptContext,
|
|
670
|
+
kbid: str,
|
|
671
|
+
ordered_text_blocks: list[FindParagraph],
|
|
672
|
+
strategy: NeighbouringParagraphsStrategy,
|
|
673
|
+
metrics: Metrics,
|
|
674
|
+
augmented_context: AugmentedContext,
|
|
675
|
+
) -> None:
|
|
676
|
+
"""
|
|
677
|
+
This function will get the paragraph texts and then craft a context with the neighbouring paragraphs of the
|
|
678
|
+
paragraphs in the ordered_paragraphs list.
|
|
679
|
+
"""
|
|
680
|
+
retrieved_paragraphs_ids = [
|
|
681
|
+
ParagraphId.from_string(text_block.id) for text_block in ordered_text_blocks
|
|
682
|
+
]
|
|
683
|
+
unique_field_ids = list({pid.field_id for pid in retrieved_paragraphs_ids})
|
|
684
|
+
|
|
685
|
+
# Get extracted texts and metadatas for all fields
|
|
686
|
+
fm_ops = []
|
|
687
|
+
et_ops = []
|
|
688
|
+
for field_id in unique_field_ids:
|
|
689
|
+
field = await get_orm_field(kbid, field_id)
|
|
690
|
+
if field is None:
|
|
691
|
+
continue
|
|
692
|
+
fm_ops.append(asyncio.create_task(field.get_field_metadata()))
|
|
693
|
+
et_ops.append(asyncio.create_task(field.get_extracted_text()))
|
|
694
|
+
|
|
695
|
+
field_metadatas: dict[FieldId, FieldComputedMetadata] = {
|
|
696
|
+
fid: fm for fid, fm in zip(unique_field_ids, await asyncio.gather(*fm_ops)) if fm is not None
|
|
697
|
+
}
|
|
698
|
+
extracted_texts: dict[FieldId, ExtractedText] = {
|
|
699
|
+
fid: et for fid, et in zip(unique_field_ids, await asyncio.gather(*et_ops)) if et is not None
|
|
700
|
+
}
|
|
701
|
+
|
|
702
|
+
def _get_paragraph_text(extracted_text: ExtractedText, pid: ParagraphId) -> str:
|
|
703
|
+
if pid.field_id.subfield_id:
|
|
704
|
+
text = extracted_text.split_text.get(pid.field_id.subfield_id) or ""
|
|
705
|
+
else:
|
|
706
|
+
text = extracted_text.text
|
|
707
|
+
return text[pid.paragraph_start : pid.paragraph_end]
|
|
708
|
+
|
|
709
|
+
for pid in retrieved_paragraphs_ids:
|
|
710
|
+
# Add the retrieved paragraph first
|
|
711
|
+
field_extracted_text = extracted_texts.get(pid.field_id, None)
|
|
712
|
+
if field_extracted_text is None:
|
|
713
|
+
continue
|
|
714
|
+
ptext = _get_paragraph_text(field_extracted_text, pid)
|
|
715
|
+
if ptext and pid.full() not in context:
|
|
716
|
+
context[pid.full()] = ptext
|
|
717
|
+
|
|
718
|
+
# Now add the neighbouring paragraphs
|
|
719
|
+
field_extracted_metadata = field_metadatas.get(pid.field_id, None)
|
|
720
|
+
if field_extracted_metadata is None:
|
|
721
|
+
continue
|
|
722
|
+
|
|
723
|
+
field_pids = [
|
|
724
|
+
ParagraphId(
|
|
725
|
+
field_id=pid.field_id,
|
|
726
|
+
paragraph_start=p.start,
|
|
727
|
+
paragraph_end=p.end,
|
|
728
|
+
)
|
|
729
|
+
for p in field_extracted_metadata.metadata.paragraphs
|
|
730
|
+
]
|
|
731
|
+
try:
|
|
732
|
+
index = field_pids.index(pid)
|
|
733
|
+
except ValueError:
|
|
734
|
+
continue
|
|
735
|
+
|
|
736
|
+
for neighbour_index in get_neighbouring_indices(
|
|
737
|
+
index=index,
|
|
738
|
+
before=strategy.before,
|
|
739
|
+
after=strategy.after,
|
|
740
|
+
field_pids=field_pids,
|
|
741
|
+
):
|
|
742
|
+
if neighbour_index == index:
|
|
743
|
+
# Already handled above
|
|
744
|
+
continue
|
|
745
|
+
try:
|
|
746
|
+
npid = field_pids[neighbour_index]
|
|
747
|
+
except IndexError:
|
|
748
|
+
continue
|
|
749
|
+
if npid in retrieved_paragraphs_ids or npid.full() in context:
|
|
750
|
+
# Already added
|
|
751
|
+
continue
|
|
752
|
+
ptext = _get_paragraph_text(field_extracted_text, npid)
|
|
753
|
+
if not ptext:
|
|
754
|
+
continue
|
|
755
|
+
context[npid.full()] = ptext
|
|
756
|
+
augmented_context.paragraphs[npid.full()] = AugmentedTextBlock(
|
|
757
|
+
id=npid.full(),
|
|
758
|
+
text=ptext,
|
|
759
|
+
position=get_text_position(npid, neighbour_index, field_extracted_metadata),
|
|
760
|
+
parent=pid.full(),
|
|
761
|
+
augmentation_type=TextBlockAugmentationType.NEIGHBOURING_PARAGRAPHS,
|
|
762
|
+
)
|
|
763
|
+
|
|
764
|
+
metrics.set("neighbouring_paragraphs_ops", len(augmented_context.paragraphs))
|
|
765
|
+
|
|
766
|
+
|
|
767
|
+
def get_text_position(
|
|
768
|
+
paragraph_id: ParagraphId, index: int, field_metadata: FieldComputedMetadata
|
|
769
|
+
) -> TextPosition | None:
|
|
770
|
+
if paragraph_id.field_id.subfield_id:
|
|
771
|
+
metadata = field_metadata.split_metadata[paragraph_id.field_id.subfield_id]
|
|
772
|
+
else:
|
|
773
|
+
metadata = field_metadata.metadata
|
|
774
|
+
try:
|
|
775
|
+
pmetadata = metadata.paragraphs[index]
|
|
776
|
+
except IndexError:
|
|
777
|
+
return None
|
|
778
|
+
page_number = None
|
|
779
|
+
if pmetadata.HasField("page"):
|
|
780
|
+
page_number = pmetadata.page.page
|
|
781
|
+
return TextPosition(
|
|
782
|
+
page_number=page_number,
|
|
783
|
+
index=index,
|
|
784
|
+
start=pmetadata.start,
|
|
785
|
+
end=pmetadata.end,
|
|
786
|
+
start_seconds=list(pmetadata.start_seconds),
|
|
787
|
+
end_seconds=list(pmetadata.end_seconds),
|
|
788
|
+
)
|
|
789
|
+
|
|
790
|
+
|
|
791
|
+
def get_neighbouring_indices(
|
|
792
|
+
index: int, before: int, after: int, field_pids: list[ParagraphId]
|
|
793
|
+
) -> list[int]:
|
|
794
|
+
lb_index = max(0, index - before)
|
|
795
|
+
ub_index = min(len(field_pids), index + after + 1)
|
|
796
|
+
return list(range(lb_index, index)) + list(range(index + 1, ub_index))
|
|
797
|
+
|
|
798
|
+
|
|
799
|
+
async def conversation_prompt_context(
|
|
800
|
+
context: CappedPromptContext,
|
|
801
|
+
kbid: str,
|
|
802
|
+
ordered_paragraphs: list[FindParagraph],
|
|
803
|
+
strategy: ConversationalStrategy,
|
|
804
|
+
visual_llm: bool,
|
|
805
|
+
metrics: Metrics,
|
|
806
|
+
augmented_context: AugmentedContext,
|
|
807
|
+
):
|
|
808
|
+
analyzed_fields: list[str] = []
|
|
809
|
+
ops = 0
|
|
810
|
+
async with get_driver().ro_transaction() as txn:
|
|
811
|
+
storage = await get_storage()
|
|
812
|
+
kb = KnowledgeBoxORM(txn, storage, kbid)
|
|
813
|
+
for paragraph in ordered_paragraphs:
|
|
814
|
+
if paragraph.id not in context:
|
|
815
|
+
context[paragraph.id] = _clean_paragraph_text(paragraph)
|
|
816
|
+
|
|
817
|
+
# If the paragraph is a conversation and it matches semantically, we assume we
|
|
818
|
+
# have matched with the question, therefore try to include the answer to the
|
|
819
|
+
# context by pulling the next few messages of the conversation field
|
|
820
|
+
rid, field_type, field_id, mident = paragraph.id.split("/")[:4]
|
|
821
|
+
if field_type == "c" and paragraph.score_type in (
|
|
822
|
+
SCORE_TYPE.VECTOR,
|
|
823
|
+
SCORE_TYPE.BOTH,
|
|
824
|
+
SCORE_TYPE.BM25,
|
|
825
|
+
):
|
|
826
|
+
field_unique_id = "-".join([rid, field_type, field_id])
|
|
827
|
+
if field_unique_id in analyzed_fields:
|
|
828
|
+
continue
|
|
829
|
+
resource = await kb.get(rid)
|
|
830
|
+
if resource is None: # pragma: no cover
|
|
831
|
+
continue
|
|
832
|
+
|
|
833
|
+
field_obj: Conversation = await resource.get_field(
|
|
834
|
+
field_id, FIELD_TYPE_STR_TO_PB["c"], load=True
|
|
835
|
+
) # type: ignore
|
|
836
|
+
cmetadata = await field_obj.get_metadata()
|
|
837
|
+
|
|
838
|
+
attachments: list[resources_pb2.FieldRef] = []
|
|
839
|
+
if strategy.full:
|
|
840
|
+
ops += 5
|
|
841
|
+
extracted_text = await field_obj.get_extracted_text()
|
|
842
|
+
for current_page in range(1, cmetadata.pages + 1):
|
|
843
|
+
conv = await field_obj.db_get_value(current_page)
|
|
844
|
+
|
|
845
|
+
for message in conv.messages:
|
|
846
|
+
ident = message.ident
|
|
847
|
+
if extracted_text is not None:
|
|
848
|
+
text = extracted_text.split_text.get(ident, message.content.text.strip())
|
|
849
|
+
else:
|
|
850
|
+
text = message.content.text.strip()
|
|
851
|
+
pid = f"{rid}/{field_type}/{field_id}/{ident}/0-{len(text)}"
|
|
852
|
+
attachments.extend(message.content.attachments_fields)
|
|
853
|
+
if pid in context:
|
|
854
|
+
continue
|
|
855
|
+
context[pid] = text
|
|
856
|
+
augmented_context.paragraphs[pid] = AugmentedTextBlock(
|
|
857
|
+
id=pid,
|
|
858
|
+
text=text,
|
|
859
|
+
parent=paragraph.id,
|
|
860
|
+
augmentation_type=TextBlockAugmentationType.CONVERSATION,
|
|
861
|
+
)
|
|
862
|
+
else:
|
|
863
|
+
# Add first message
|
|
864
|
+
extracted_text = await field_obj.get_extracted_text()
|
|
865
|
+
first_page = await field_obj.db_get_value()
|
|
866
|
+
if len(first_page.messages) > 0:
|
|
867
|
+
message = first_page.messages[0]
|
|
868
|
+
ident = message.ident
|
|
869
|
+
if extracted_text is not None:
|
|
870
|
+
text = extracted_text.split_text.get(ident, message.content.text.strip())
|
|
871
|
+
else:
|
|
872
|
+
text = message.content.text.strip()
|
|
873
|
+
attachments.extend(message.content.attachments_fields)
|
|
874
|
+
pid = f"{rid}/{field_type}/{field_id}/{ident}/0-{len(text)}"
|
|
875
|
+
if pid in context:
|
|
876
|
+
continue
|
|
877
|
+
context[pid] = text
|
|
878
|
+
augmented_context.paragraphs[pid] = AugmentedTextBlock(
|
|
879
|
+
id=pid,
|
|
880
|
+
text=text,
|
|
881
|
+
parent=paragraph.id,
|
|
882
|
+
augmentation_type=TextBlockAugmentationType.CONVERSATION,
|
|
883
|
+
)
|
|
884
|
+
|
|
885
|
+
messages: Deque[resources_pb2.Message] = deque(maxlen=strategy.max_messages)
|
|
886
|
+
|
|
887
|
+
pending = -1
|
|
888
|
+
for page in range(1, cmetadata.pages + 1):
|
|
889
|
+
# Collect the messages with the window asked by the user arround the match paragraph
|
|
890
|
+
conv = await field_obj.db_get_value(page)
|
|
891
|
+
for message in conv.messages:
|
|
892
|
+
messages.append(message)
|
|
893
|
+
if pending > 0:
|
|
894
|
+
pending -= 1
|
|
895
|
+
if message.ident == mident:
|
|
896
|
+
pending = (strategy.max_messages - 1) // 2
|
|
897
|
+
if pending == 0:
|
|
898
|
+
break
|
|
899
|
+
if pending == 0:
|
|
900
|
+
break
|
|
901
|
+
|
|
902
|
+
for message in messages:
|
|
903
|
+
ops += 1
|
|
904
|
+
text = message.content.text.strip()
|
|
905
|
+
attachments.extend(message.content.attachments_fields)
|
|
906
|
+
pid = f"{rid}/{field_type}/{field_id}/{message.ident}/0-{len(message.content.text)}"
|
|
907
|
+
if pid in context:
|
|
908
|
+
continue
|
|
909
|
+
context[pid] = text
|
|
910
|
+
augmented_context.paragraphs[pid] = AugmentedTextBlock(
|
|
911
|
+
id=pid,
|
|
912
|
+
text=text,
|
|
913
|
+
parent=paragraph.id,
|
|
914
|
+
augmentation_type=TextBlockAugmentationType.CONVERSATION,
|
|
915
|
+
)
|
|
916
|
+
|
|
917
|
+
if strategy.attachments_text:
|
|
918
|
+
# add on the context the images if vlm enabled
|
|
919
|
+
for attachment in attachments:
|
|
920
|
+
ops += 1
|
|
921
|
+
field: File = await resource.get_field(
|
|
922
|
+
attachment.field_id, attachment.field_type, load=True
|
|
923
|
+
) # type: ignore
|
|
924
|
+
extracted_text = await field.get_extracted_text()
|
|
925
|
+
if extracted_text is not None:
|
|
926
|
+
attachment_field_type = FIELD_TYPE_PB_TO_STR[attachment.field_type]
|
|
927
|
+
pid = f"{rid}/{attachment_field_type}/{attachment.field_id}/0-{len(extracted_text.text)}"
|
|
928
|
+
if pid in context:
|
|
929
|
+
continue
|
|
930
|
+
text = f"Attachment {attachment.field_id}: {extracted_text.text}\n\n"
|
|
931
|
+
context[pid] = text
|
|
932
|
+
augmented_context.paragraphs[pid] = AugmentedTextBlock(
|
|
933
|
+
id=pid,
|
|
934
|
+
text=text,
|
|
935
|
+
parent=paragraph.id,
|
|
936
|
+
augmentation_type=TextBlockAugmentationType.CONVERSATION,
|
|
937
|
+
)
|
|
938
|
+
|
|
939
|
+
if strategy.attachments_images and visual_llm:
|
|
940
|
+
for attachment in attachments:
|
|
941
|
+
ops += 1
|
|
942
|
+
file_field: File = await resource.get_field(
|
|
943
|
+
attachment.field_id, attachment.field_type, load=True
|
|
944
|
+
) # type: ignore
|
|
945
|
+
image = await get_file_thumbnail_image(file_field)
|
|
946
|
+
if image is not None:
|
|
947
|
+
pid = f"{rid}/f/{attachment.field_id}/0-0"
|
|
948
|
+
context.images[pid] = image
|
|
949
|
+
|
|
950
|
+
analyzed_fields.append(field_unique_id)
|
|
951
|
+
metrics.set("conversation_ops", ops)
|
|
952
|
+
|
|
953
|
+
|
|
954
|
+
async def hierarchy_prompt_context(
|
|
955
|
+
context: CappedPromptContext,
|
|
956
|
+
kbid: str,
|
|
957
|
+
ordered_paragraphs: list[FindParagraph],
|
|
958
|
+
strategy: HierarchyResourceStrategy,
|
|
959
|
+
metrics: Metrics,
|
|
960
|
+
augmented_context: AugmentedContext,
|
|
961
|
+
) -> None:
|
|
962
|
+
"""
|
|
963
|
+
This function will get the paragraph texts (possibly with extra characters, if extra_characters > 0) and then
|
|
964
|
+
    craft a context with all paragraphs of the same resource grouped together. Moreover, for each group of paragraphs,
    it includes the resource title and summary so that the LLM can have a better understanding of the context.
    """
    paragraphs_extra_characters = max(strategy.count, 0)
    # Make a copy of the ordered paragraphs to avoid modifying the original list, which is returned
    # in the response to the user
    ordered_paragraphs_copy = copy.deepcopy(ordered_paragraphs)
    resources: dict[str, ExtraCharsParagraph] = {}

    # Iterate paragraphs to get extended text
    for paragraph in ordered_paragraphs_copy:
        paragraph_id = ParagraphId.from_string(paragraph.id)
        extended_paragraph_text = paragraph.text
        if paragraphs_extra_characters > 0:
            extended_paragraph_id = ParagraphId(
                field_id=paragraph_id.field_id,
                paragraph_start=paragraph_id.paragraph_start,
                paragraph_end=paragraph_id.paragraph_end + paragraphs_extra_characters,
            )
            extended_paragraph_text = await get_paragraph_text(
                kbid=kbid,
                paragraph_id=extended_paragraph_id,
                log_on_missing_field=True,
            )
        rid = paragraph_id.rid
        if rid not in resources:
            # Get the title and the summary of the resource
            title_text = await get_paragraph_text(
                kbid=kbid,
                paragraph_id=ParagraphId(
                    field_id=FieldId(
                        rid=rid,
                        type="a",
                        key="title",
                    ),
                    paragraph_start=0,
                    paragraph_end=500,
                ),
                log_on_missing_field=False,
            )
            summary_text = await get_paragraph_text(
                kbid=kbid,
                paragraph_id=ParagraphId(
                    field_id=FieldId(
                        rid=rid,
                        type="a",
                        key="summary",
                    ),
                    paragraph_start=0,
                    paragraph_end=1000,
                ),
                log_on_missing_field=False,
            )
            resources[rid] = ExtraCharsParagraph(
                title=title_text,
                summary=summary_text,
                paragraphs=[(paragraph, extended_paragraph_text)],
            )
        else:
            resources[rid].paragraphs.append((paragraph, extended_paragraph_text))

    metrics.set("hierarchy_ops", len(resources))
    augmented_paragraphs = set()

    # Modify the first paragraph of each resource to include the title and summary of the resource, as well as the
    # extended paragraph text of all the paragraphs in the resource.
    for values in resources.values():
        title_text = values.title
        summary_text = values.summary
        first_paragraph = None
        text_with_hierarchy = ""
        for paragraph, extended_paragraph_text in values.paragraphs:
            if first_paragraph is None:
                first_paragraph = paragraph
            text_with_hierarchy += "\n EXTRACTED BLOCK: \n " + extended_paragraph_text + " \n\n "
            # All paragraphs of the resource are cleared except the first one, which will be the
            # one containing the whole hierarchy information
            paragraph.text = ""

        if first_paragraph is not None:
            # The first paragraph is the only one holding the hierarchy information
            first_paragraph.text = f"DOCUMENT: {title_text} \n SUMMARY: {summary_text} \n RESOURCE CONTENT: {text_with_hierarchy}"
            augmented_paragraphs.add(first_paragraph.id)

    # Now that the paragraphs have been modified, we can add them to the context
    for paragraph in ordered_paragraphs_copy:
        if paragraph.text == "":
            # Skip paragraphs that were cleared in the hierarchy expansion
            continue
        paragraph_text = _clean_paragraph_text(paragraph)
        context[paragraph.id] = paragraph_text
        if paragraph.id in augmented_paragraphs:
            pid = ParagraphId.from_string(paragraph.id)
            augmented_context.paragraphs[pid.full()] = AugmentedTextBlock(
                id=pid.full(), text=paragraph_text, augmentation_type=TextBlockAugmentationType.HIERARCHY
            )
    return
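
The hierarchy strategy above is easier to see end to end on toy data. The following self-contained sketch mirrors its core idea: paragraphs are bucketed by resource id, and the first paragraph of each bucket is rewritten to carry the title, summary, and all extended texts. The Paragraph dataclass and build_hierarchy_blocks helper are illustrative stand-ins, not nucliadb APIs.

# --- illustrative sketch (not part of nucliadb) ---
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class Paragraph:  # hypothetical stand-in for FindParagraph
    id: str
    rid: str
    text: str


def build_hierarchy_blocks(
    paragraphs: list[Paragraph], titles: dict[str, str], summaries: dict[str, str]
) -> dict[str, str]:
    # Bucket paragraphs by resource id, preserving retrieval order
    grouped: dict[str, list[Paragraph]] = defaultdict(list)
    for p in paragraphs:
        grouped[p.rid].append(p)
    context: dict[str, str] = {}
    for rid, group in grouped.items():
        body = "".join(f"\n EXTRACTED BLOCK: \n {p.text} \n\n " for p in group)
        # Only the first paragraph of each resource carries the hierarchy text
        first = group[0]
        context[first.id] = (
            f"DOCUMENT: {titles.get(rid, '')} \n SUMMARY: {summaries.get(rid, '')} \n RESOURCE CONTENT: {body}"
        )
    return context


if __name__ == "__main__":
    ps = [
        Paragraph("r1/f/doc/0-10", "r1", "first block"),
        Paragraph("r1/f/doc/10-20", "r1", "second block"),
    ]
    print(build_hierarchy_blocks(ps, {"r1": "My doc"}, {"r1": "A short summary"}))
# --- end sketch ---
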
class PromptContextBuilder:
    """
    Builds the context for the LLM prompt.
    """

    def __init__(
        self,
        kbid: str,
        ordered_paragraphs: list[FindParagraph],
        resource: str | None = None,
        user_context: list[str] | None = None,
        user_image_context: list[Image] | None = None,
        strategies: Sequence[RagStrategy] | None = None,
        image_strategies: Sequence[ImageRagStrategy] | None = None,
        max_context_characters: int | None = None,
        visual_llm: bool = False,
        query_image: Image | None = None,
        metrics: Metrics = Metrics("prompt_context_builder"),
    ):
        self.kbid = kbid
        self.ordered_paragraphs = ordered_paragraphs
        self.resource = resource
        self.user_context = user_context
        self.user_image_context = user_image_context
        self.strategies = strategies
        self.image_strategies = image_strategies
        self.max_context_characters = max_context_characters
        self.visual_llm = visual_llm
        self.metrics = metrics
        self.query_image = query_image
        self.augmented_context = AugmentedContext(paragraphs={}, fields={})

    def prepend_user_context(self, context: CappedPromptContext):
        # Chat extra context passed by the user is the most important, therefore
        # it is added first, followed by the found text blocks in order of relevance
        for i, text_block in enumerate(self.user_context or []):
            context[f"USER_CONTEXT_{i}"] = text_block
        # Add the query image as part of the image context
        if self.query_image is not None:
            context.images["QUERY_IMAGE"] = self.query_image
        else:
            for i, image in enumerate(self.user_image_context or []):
                context.images[f"USER_IMAGE_CONTEXT_{i}"] = image

    async def build(
        self,
    ) -> tuple[PromptContext, PromptContextOrder, PromptContextImages, AugmentedContext]:
        ccontext = CappedPromptContext(max_size=self.max_context_characters)
        self.prepend_user_context(ccontext)
        await self._build_context(ccontext)
        if self.visual_llm and not self.query_image:
            await self._build_context_images(ccontext)
        context = ccontext.cap()
        context_images = ccontext.images
        context_order = {text_block_id: order for order, text_block_id in enumerate(context.keys())}
        return context, context_order, context_images, self.augmented_context
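
Note that build() derives the order map purely from dict insertion order, so user context ranks first and retrieved blocks follow by relevance. A rough, self-contained sketch of that behavior follows; TinyCappedContext is a stand-in, and its cap() trimming policy (keep entries in insertion order until a character budget runs out) is an assumption, not the documented behavior of CappedPromptContext.

# --- illustrative sketch (not part of nucliadb) ---
class TinyCappedContext:
    def __init__(self, max_size: int | None = None):
        self.max_size = max_size
        self.output: dict[str, str] = {}

    def __setitem__(self, key: str, value: str) -> None:
        self.output[key] = value

    def cap(self) -> dict[str, str]:
        # Assumed policy: keep entries in insertion order until the budget runs out
        if self.max_size is None:
            return self.output
        capped: dict[str, str] = {}
        budget = self.max_size
        for key, text in self.output.items():
            if len(text) > budget:
                break
            capped[key] = text
            budget -= len(text)
        return capped


ctx = TinyCappedContext(max_size=20)
ctx["USER_CONTEXT_0"] = "hello"      # user context is prepended first
ctx["r1/f/doc/0-10"] = "wordy block"  # then retrieved blocks, by relevance
ctx["r1/f/doc/10-20"] = "dropped"     # exceeds the remaining budget
context = ctx.cap()
context_order = {tb_id: order for order, tb_id in enumerate(context.keys())}
print(context_order)  # {'USER_CONTEXT_0': 0, 'r1/f/doc/0-10': 1}
# --- end sketch ---
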
    async def _build_context_images(self, context: CappedPromptContext) -> None:
        ops = 0
        if self.image_strategies is None or len(self.image_strategies) == 0:
            # Nothing to do
            return
        page_image_strategy: PageImageStrategy | None = None
        max_page_images = 5
        table_image_strategy: TableImageStrategy | None = None
        paragraph_image_strategy: ParagraphImageStrategy | None = None
        for strategy in self.image_strategies:
            if strategy.name == ImageRagStrategyName.PAGE_IMAGE:
                if page_image_strategy is None:
                    page_image_strategy = cast(PageImageStrategy, strategy)
                    if page_image_strategy.count is not None:
                        max_page_images = page_image_strategy.count
            elif strategy.name == ImageRagStrategyName.TABLES:
                if table_image_strategy is None:
                    table_image_strategy = cast(TableImageStrategy, strategy)
            elif strategy.name == ImageRagStrategyName.PARAGRAPH_IMAGE:
                if paragraph_image_strategy is None:
                    paragraph_image_strategy = cast(ParagraphImageStrategy, strategy)
            else:  # pragma: no cover
                logger.warning(
                    "Unknown image strategy",
                    extra={"strategy": strategy.name, "kbid": self.kbid},
                )
        page_images_added = 0
        for paragraph in self.ordered_paragraphs:
            pid = ParagraphId.from_string(paragraph.id)
            paragraph_page_number = get_paragraph_page_number(paragraph)
            if (
                page_image_strategy is not None
                and page_images_added < max_page_images
                and paragraph_page_number is not None
            ):
                # page_image_id: rid/f/myfield/0
                page_image_id = "/".join([pid.field_id.full(), str(paragraph_page_number)])
                if page_image_id not in context.images:
                    image = await get_page_image(self.kbid, pid, paragraph_page_number)
                    if image is not None:
                        ops += 1
                        context.images[page_image_id] = image
                        page_images_added += 1
                    else:
                        logger.warning(
                            "Could not retrieve image for paragraph from storage",
                            extra={
                                "kbid": self.kbid,
                                "paragraph": pid.full(),
                                "page_number": paragraph_page_number,
                            },
                        )

            add_table = table_image_strategy is not None and paragraph.is_a_table
            add_paragraph = paragraph_image_strategy is not None and not paragraph.is_a_table
            if (add_table or add_paragraph) and (
                paragraph.reference is not None and paragraph.reference != ""
            ):
                pimage = await get_paragraph_image(self.kbid, pid, paragraph.reference)
                if pimage is not None:
                    ops += 1
                    context.images[paragraph.id] = pimage
                else:
                    logger.warning(
                        "Could not retrieve image for paragraph from storage",
                        extra={
                            "kbid": self.kbid,
                            "paragraph": pid.full(),
                            "reference": paragraph.reference,
                        },
                    )
        self.metrics.set("image_ops", ops)
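
As the page_image_id comment in the method above notes, page images are keyed by the field id plus the page number, and the membership check deduplicates pages shared by several paragraphs. A tiny sketch with made-up ids:

# --- illustrative sketch (not part of nucliadb) ---
images: dict[str, str] = {}
for field_full, page in [("rid1/f/doc", 0), ("rid1/f/doc", 0), ("rid1/f/doc", 1)]:
    page_image_id = "/".join([field_full, str(page)])  # e.g. "rid1/f/doc/0"
    if page_image_id not in images:  # paragraphs on the same page share one image
        images[page_image_id] = f"<image for {page_image_id}>"
print(list(images))  # ['rid1/f/doc/0', 'rid1/f/doc/1']
# --- end sketch ---
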
    async def _build_context(self, context: CappedPromptContext) -> None:
        if self.strategies is None or len(self.strategies) == 0:
            # When no strategy is specified, use the default one
            await default_prompt_context(context, self.kbid, self.ordered_paragraphs)
            return
        else:
            # Add the paragraphs to the context and then apply the strategies
            for paragraph in self.ordered_paragraphs:
                context[paragraph.id] = _clean_paragraph_text(paragraph)

        strategies_not_handled_here = [
            RagStrategyName.PREQUERIES,
            RagStrategyName.GRAPH,
        ]

        full_resource: FullResourceStrategy | None = None
        hierarchy: HierarchyResourceStrategy | None = None
        neighbouring_paragraphs: NeighbouringParagraphsStrategy | None = None
        field_extension: FieldExtensionStrategy | None = None
        metadata_extension: MetadataExtensionStrategy | None = None
        conversational_strategy: ConversationalStrategy | None = None
        for strategy in self.strategies:
            if strategy.name == RagStrategyName.FIELD_EXTENSION:
                field_extension = cast(FieldExtensionStrategy, strategy)
            elif strategy.name == RagStrategyName.CONVERSATION:
                conversational_strategy = cast(ConversationalStrategy, strategy)
            elif strategy.name == RagStrategyName.FULL_RESOURCE:
                full_resource = cast(FullResourceStrategy, strategy)
                if self.resource:  # pragma: no cover
                    # When the retrieval is scoped to a specific resource
                    # the full resource strategy only includes that resource
                    full_resource.count = 1
            elif strategy.name == RagStrategyName.HIERARCHY:
                hierarchy = cast(HierarchyResourceStrategy, strategy)
            elif strategy.name == RagStrategyName.NEIGHBOURING_PARAGRAPHS:
                neighbouring_paragraphs = cast(NeighbouringParagraphsStrategy, strategy)
            elif strategy.name == RagStrategyName.METADATA_EXTENSION:
                metadata_extension = cast(MetadataExtensionStrategy, strategy)
            elif strategy.name not in strategies_not_handled_here:  # pragma: no cover
                # Prequeries and graph are not handled here
                logger.warning(
                    "Unknown rag strategy",
                    extra={"strategy": strategy.name, "kbid": self.kbid},
                )

        if full_resource:
            # When full resource is enabled, only metadata extension is allowed.
            await full_resource_prompt_context(
                context,
                self.kbid,
                self.ordered_paragraphs,
                self.resource,
                full_resource,
                self.metrics,
                self.augmented_context,
            )
            if metadata_extension:
                await extend_prompt_context_with_metadata(
                    context,
                    self.kbid,
                    metadata_extension,
                    self.metrics,
                    self.augmented_context,
                )
            return

        if hierarchy:
            await hierarchy_prompt_context(
                context,
                self.kbid,
                self.ordered_paragraphs,
                hierarchy,
                self.metrics,
                self.augmented_context,
            )
        if neighbouring_paragraphs:
            await neighbouring_paragraphs_prompt_context(
                context,
                self.kbid,
                self.ordered_paragraphs,
                neighbouring_paragraphs,
                self.metrics,
                self.augmented_context,
            )
        if field_extension:
            await field_extension_prompt_context(
                context,
                self.kbid,
                self.ordered_paragraphs,
                field_extension,
                self.metrics,
                self.augmented_context,
            )
        if conversational_strategy:
            await conversation_prompt_context(
                context,
                self.kbid,
                self.ordered_paragraphs,
                conversational_strategy,
                self.visual_llm,
                self.metrics,
                self.augmented_context,
            )
        if metadata_extension:
            await extend_prompt_context_with_metadata(
                context,
                self.kbid,
                metadata_extension,
                self.metrics,
                self.augmented_context,
            )


def get_paragraph_page_number(paragraph: FindParagraph) -> int | None:
    if not paragraph.page_with_visual:
        return None
    if paragraph.position is None:
        return None
    return paragraph.position.page_number


@dataclass
class ExtraCharsParagraph:
    title: str
    summary: str
    paragraphs: list[tuple[FindParagraph, str]]


def _clean_paragraph_text(paragraph: FindParagraph) -> str:
    text = paragraph.text.strip()
    # Do not send highlight marks on prompt context
    text = text.replace("<mark>", "").replace("</mark>", "")
    return text
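
A quick check of the highlight stripping performed by _clean_paragraph_text above; _clean is a local stand-in that applies the same transformation to a bare string:

# --- illustrative sketch (not part of nucliadb) ---
def _clean(text: str) -> str:
    # Same transformation as _clean_paragraph_text, applied to a bare string
    return text.strip().replace("<mark>", "").replace("</mark>", "")


print(_clean("  The <mark>answer</mark> is 42.  "))  # -> "The answer is 42."
# --- end sketch ---
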
# COPY from hydrator/__init__.py that has been refactored and removed


hydrator_observer = Observer("hydrator", labels={"type": ""})


@hydrator_observer.wrap({"type": "resource_text"})
async def hydrate_resource_text(
    kbid: str, rid: str, *, max_concurrent_tasks: int
) -> list[tuple[FieldId, str]]:
    resource = await cache.get_resource(kbid, rid)
    if resource is None:  # pragma: no cover
        return []

    # Schedule the extraction of the text of each field in the resource
    async with get_driver().ro_transaction() as txn:
        resource.txn = txn
        runner = ConcurrentRunner(max_tasks=max_concurrent_tasks)
        for field_type, field_key in await resource.get_fields(force=True):
            field_id = FieldId.from_pb(rid, field_type, field_key)
            runner.schedule(hydrate_field_text(kbid, field_id))

        # Include the summary as well
        runner.schedule(hydrate_field_text(kbid, FieldId(rid=rid, type="a", key="summary")))

        # Wait for the results
        field_extracted_texts = await runner.wait()

    return [text for text in field_extracted_texts if text is not None]
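
ConcurrentRunner itself is not part of this diff; assuming it runs scheduled coroutines concurrently with at most max_tasks in flight, an equivalent pattern using only asyncio primitives would look roughly like this (BoundedRunner is a sketch, not the real class):

# --- illustrative sketch (not part of nucliadb) ---
import asyncio


class BoundedRunner:
    # Assumed behavior of ConcurrentRunner: run scheduled coroutines
    # concurrently, never more than max_tasks at a time.
    def __init__(self, max_tasks: int):
        self._semaphore = asyncio.Semaphore(max_tasks)
        self._tasks: list[asyncio.Task] = []

    def schedule(self, coro) -> None:
        async def _bounded():
            async with self._semaphore:
                return await coro

        self._tasks.append(asyncio.create_task(_bounded()))

    async def wait(self) -> list:
        # Results come back in scheduling order, like the call sites above expect
        return await asyncio.gather(*self._tasks)


async def main():
    runner = BoundedRunner(max_tasks=2)
    for i in range(5):
        runner.schedule(asyncio.sleep(0.01, result=i))
    print(await runner.wait())  # [0, 1, 2, 3, 4]


asyncio.run(main())
# --- end sketch ---
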
@hydrator_observer.wrap({"type": "field_text"})
|
|
1360
|
+
async def hydrate_field_text(
|
|
1361
|
+
kbid: str,
|
|
1362
|
+
field_id: FieldId,
|
|
1363
|
+
) -> tuple[FieldId, str] | None:
|
|
1364
|
+
field = await cache.get_field(kbid, field_id)
|
|
1365
|
+
if field is None: # pragma: no cover
|
|
1366
|
+
return None
|
|
1367
|
+
|
|
1368
|
+
extracted_text_pb = await cache.get_field_extracted_text(field)
|
|
1369
|
+
if extracted_text_pb is None: # pragma: no cover
|
|
1370
|
+
return None
|
|
1371
|
+
|
|
1372
|
+
if field_id.subfield_id:
|
|
1373
|
+
return field_id, extracted_text_pb.split_text[field_id.subfield_id]
|
|
1374
|
+
else:
|
|
1375
|
+
return field_id, extracted_text_pb.text
|
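
For fields with subfields (for example, individual messages of a conversation), hydrate_field_text picks the matching entry from the extracted text's split_text map instead of the full text. A minimal illustration, with a plain dict standing in for the protobuf map field:

# --- illustrative sketch (not part of nucliadb) ---
split_text = {"msg-1": "hello", "msg-2": "bye"}  # stand-in for ExtractedText.split_text
full_text = "hello\nbye"
subfield_id = "msg-2"
print(split_text[subfield_id] if subfield_id else full_text)  # -> bye
# --- end sketch ---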