nucliadb 6.7.2.post4874__py3-none-any.whl → 6.10.0.post5705__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- migrations/0023_backfill_pg_catalog.py +8 -4
- migrations/0028_extracted_vectors_reference.py +1 -1
- migrations/0029_backfill_field_status.py +3 -4
- migrations/0032_remove_old_relations.py +2 -3
- migrations/0038_backfill_catalog_field_labels.py +8 -4
- migrations/0039_backfill_converation_splits_metadata.py +106 -0
- migrations/0040_migrate_search_configurations.py +79 -0
- migrations/0041_reindex_conversations.py +137 -0
- migrations/pg/0010_shards_index.py +34 -0
- nucliadb/search/api/v1/resource/utils.py → migrations/pg/0011_catalog_statistics.py +5 -6
- migrations/pg/0012_catalog_statistics_undo.py +26 -0
- nucliadb/backups/create.py +2 -15
- nucliadb/backups/restore.py +4 -15
- nucliadb/backups/tasks.py +4 -1
- nucliadb/common/back_pressure/cache.py +2 -3
- nucliadb/common/back_pressure/materializer.py +7 -13
- nucliadb/common/back_pressure/settings.py +6 -6
- nucliadb/common/back_pressure/utils.py +1 -0
- nucliadb/common/cache.py +9 -9
- nucliadb/common/catalog/__init__.py +79 -0
- nucliadb/common/catalog/dummy.py +36 -0
- nucliadb/common/catalog/interface.py +85 -0
- nucliadb/{search/search/pgcatalog.py → common/catalog/pg.py} +330 -232
- nucliadb/common/catalog/utils.py +56 -0
- nucliadb/common/cluster/manager.py +8 -23
- nucliadb/common/cluster/rebalance.py +484 -112
- nucliadb/common/cluster/rollover.py +36 -9
- nucliadb/common/cluster/settings.py +4 -9
- nucliadb/common/cluster/utils.py +34 -8
- nucliadb/common/context/__init__.py +7 -8
- nucliadb/common/context/fastapi.py +1 -2
- nucliadb/common/datamanagers/__init__.py +2 -4
- nucliadb/common/datamanagers/atomic.py +9 -2
- nucliadb/common/datamanagers/cluster.py +1 -2
- nucliadb/common/datamanagers/fields.py +3 -4
- nucliadb/common/datamanagers/kb.py +6 -6
- nucliadb/common/datamanagers/labels.py +2 -3
- nucliadb/common/datamanagers/resources.py +10 -33
- nucliadb/common/datamanagers/rollover.py +5 -7
- nucliadb/common/datamanagers/search_configurations.py +1 -2
- nucliadb/common/datamanagers/synonyms.py +1 -2
- nucliadb/common/datamanagers/utils.py +4 -4
- nucliadb/common/datamanagers/vectorsets.py +4 -4
- nucliadb/common/external_index_providers/base.py +32 -5
- nucliadb/common/external_index_providers/manager.py +5 -34
- nucliadb/common/external_index_providers/settings.py +1 -27
- nucliadb/common/filter_expression.py +129 -41
- nucliadb/common/http_clients/exceptions.py +8 -0
- nucliadb/common/http_clients/processing.py +16 -23
- nucliadb/common/http_clients/utils.py +3 -0
- nucliadb/common/ids.py +82 -58
- nucliadb/common/locking.py +1 -2
- nucliadb/common/maindb/driver.py +9 -8
- nucliadb/common/maindb/local.py +5 -5
- nucliadb/common/maindb/pg.py +9 -8
- nucliadb/common/nidx.py +22 -5
- nucliadb/common/vector_index_config.py +1 -1
- nucliadb/export_import/datamanager.py +4 -3
- nucliadb/export_import/exporter.py +11 -19
- nucliadb/export_import/importer.py +13 -6
- nucliadb/export_import/tasks.py +2 -0
- nucliadb/export_import/utils.py +6 -18
- nucliadb/health.py +2 -2
- nucliadb/ingest/app.py +8 -8
- nucliadb/ingest/consumer/consumer.py +8 -10
- nucliadb/ingest/consumer/pull.py +10 -8
- nucliadb/ingest/consumer/service.py +5 -30
- nucliadb/ingest/consumer/shard_creator.py +16 -5
- nucliadb/ingest/consumer/utils.py +1 -1
- nucliadb/ingest/fields/base.py +37 -49
- nucliadb/ingest/fields/conversation.py +55 -9
- nucliadb/ingest/fields/exceptions.py +1 -2
- nucliadb/ingest/fields/file.py +22 -8
- nucliadb/ingest/fields/link.py +7 -7
- nucliadb/ingest/fields/text.py +2 -3
- nucliadb/ingest/orm/brain_v2.py +89 -57
- nucliadb/ingest/orm/broker_message.py +2 -4
- nucliadb/ingest/orm/entities.py +10 -209
- nucliadb/ingest/orm/index_message.py +128 -113
- nucliadb/ingest/orm/knowledgebox.py +91 -59
- nucliadb/ingest/orm/processor/auditing.py +1 -3
- nucliadb/ingest/orm/processor/data_augmentation.py +1 -2
- nucliadb/ingest/orm/processor/processor.py +98 -153
- nucliadb/ingest/orm/processor/sequence_manager.py +1 -2
- nucliadb/ingest/orm/resource.py +82 -71
- nucliadb/ingest/orm/utils.py +1 -1
- nucliadb/ingest/partitions.py +12 -1
- nucliadb/ingest/processing.py +17 -17
- nucliadb/ingest/serialize.py +202 -145
- nucliadb/ingest/service/writer.py +15 -114
- nucliadb/ingest/settings.py +36 -15
- nucliadb/ingest/utils.py +1 -2
- nucliadb/learning_proxy.py +23 -26
- nucliadb/metrics_exporter.py +20 -6
- nucliadb/middleware/__init__.py +82 -1
- nucliadb/migrator/datamanager.py +4 -11
- nucliadb/migrator/migrator.py +1 -2
- nucliadb/migrator/models.py +1 -2
- nucliadb/migrator/settings.py +1 -2
- nucliadb/models/internal/augment.py +614 -0
- nucliadb/models/internal/processing.py +19 -19
- nucliadb/openapi.py +2 -2
- nucliadb/purge/__init__.py +3 -8
- nucliadb/purge/orphan_shards.py +1 -2
- nucliadb/reader/__init__.py +5 -0
- nucliadb/reader/api/models.py +6 -13
- nucliadb/reader/api/v1/download.py +59 -38
- nucliadb/reader/api/v1/export_import.py +4 -4
- nucliadb/reader/api/v1/knowledgebox.py +37 -9
- nucliadb/reader/api/v1/learning_config.py +33 -14
- nucliadb/reader/api/v1/resource.py +61 -9
- nucliadb/reader/api/v1/services.py +18 -14
- nucliadb/reader/app.py +3 -1
- nucliadb/reader/reader/notifications.py +1 -2
- nucliadb/search/api/v1/__init__.py +3 -0
- nucliadb/search/api/v1/ask.py +3 -4
- nucliadb/search/api/v1/augment.py +585 -0
- nucliadb/search/api/v1/catalog.py +15 -19
- nucliadb/search/api/v1/find.py +16 -22
- nucliadb/search/api/v1/hydrate.py +328 -0
- nucliadb/search/api/v1/knowledgebox.py +1 -2
- nucliadb/search/api/v1/predict_proxy.py +1 -2
- nucliadb/search/api/v1/resource/ask.py +28 -8
- nucliadb/search/api/v1/resource/ingestion_agents.py +5 -6
- nucliadb/search/api/v1/resource/search.py +9 -11
- nucliadb/search/api/v1/retrieve.py +130 -0
- nucliadb/search/api/v1/search.py +28 -32
- nucliadb/search/api/v1/suggest.py +11 -14
- nucliadb/search/api/v1/summarize.py +1 -2
- nucliadb/search/api/v1/utils.py +2 -2
- nucliadb/search/app.py +3 -2
- nucliadb/search/augmentor/__init__.py +21 -0
- nucliadb/search/augmentor/augmentor.py +232 -0
- nucliadb/search/augmentor/fields.py +704 -0
- nucliadb/search/augmentor/metrics.py +24 -0
- nucliadb/search/augmentor/paragraphs.py +334 -0
- nucliadb/search/augmentor/resources.py +238 -0
- nucliadb/search/augmentor/utils.py +33 -0
- nucliadb/search/lifecycle.py +3 -1
- nucliadb/search/predict.py +33 -19
- nucliadb/search/predict_models.py +8 -9
- nucliadb/search/requesters/utils.py +11 -10
- nucliadb/search/search/cache.py +19 -42
- nucliadb/search/search/chat/ask.py +131 -59
- nucliadb/search/search/chat/exceptions.py +3 -5
- nucliadb/search/search/chat/fetcher.py +201 -0
- nucliadb/search/search/chat/images.py +6 -4
- nucliadb/search/search/chat/old_prompt.py +1375 -0
- nucliadb/search/search/chat/parser.py +510 -0
- nucliadb/search/search/chat/prompt.py +563 -615
- nucliadb/search/search/chat/query.py +453 -32
- nucliadb/search/search/chat/rpc.py +85 -0
- nucliadb/search/search/fetch.py +3 -4
- nucliadb/search/search/filters.py +8 -11
- nucliadb/search/search/find.py +33 -31
- nucliadb/search/search/find_merge.py +124 -331
- nucliadb/search/search/graph_strategy.py +14 -12
- nucliadb/search/search/hydrator/__init__.py +49 -0
- nucliadb/search/search/hydrator/fields.py +217 -0
- nucliadb/search/search/hydrator/images.py +130 -0
- nucliadb/search/search/hydrator/paragraphs.py +323 -0
- nucliadb/search/search/hydrator/resources.py +60 -0
- nucliadb/search/search/ingestion_agents.py +5 -5
- nucliadb/search/search/merge.py +90 -94
- nucliadb/search/search/metrics.py +24 -7
- nucliadb/search/search/paragraphs.py +7 -9
- nucliadb/search/search/predict_proxy.py +44 -18
- nucliadb/search/search/query.py +14 -86
- nucliadb/search/search/query_parser/fetcher.py +51 -82
- nucliadb/search/search/query_parser/models.py +19 -48
- nucliadb/search/search/query_parser/old_filters.py +20 -19
- nucliadb/search/search/query_parser/parsers/ask.py +5 -6
- nucliadb/search/search/query_parser/parsers/catalog.py +7 -11
- nucliadb/search/search/query_parser/parsers/common.py +21 -13
- nucliadb/search/search/query_parser/parsers/find.py +6 -29
- nucliadb/search/search/query_parser/parsers/graph.py +18 -28
- nucliadb/search/search/query_parser/parsers/retrieve.py +207 -0
- nucliadb/search/search/query_parser/parsers/search.py +15 -56
- nucliadb/search/search/query_parser/parsers/unit_retrieval.py +8 -29
- nucliadb/search/search/rank_fusion.py +18 -13
- nucliadb/search/search/rerankers.py +6 -7
- nucliadb/search/search/retrieval.py +300 -0
- nucliadb/search/search/summarize.py +5 -6
- nucliadb/search/search/utils.py +3 -4
- nucliadb/search/settings.py +1 -2
- nucliadb/standalone/api_router.py +1 -1
- nucliadb/standalone/app.py +4 -3
- nucliadb/standalone/auth.py +5 -6
- nucliadb/standalone/lifecycle.py +2 -2
- nucliadb/standalone/run.py +5 -4
- nucliadb/standalone/settings.py +5 -6
- nucliadb/standalone/versions.py +3 -4
- nucliadb/tasks/consumer.py +13 -8
- nucliadb/tasks/models.py +2 -1
- nucliadb/tasks/producer.py +3 -3
- nucliadb/tasks/retries.py +8 -7
- nucliadb/train/api/utils.py +1 -3
- nucliadb/train/api/v1/shards.py +1 -2
- nucliadb/train/api/v1/trainset.py +1 -2
- nucliadb/train/app.py +1 -1
- nucliadb/train/generator.py +4 -4
- nucliadb/train/generators/field_classifier.py +2 -2
- nucliadb/train/generators/field_streaming.py +6 -6
- nucliadb/train/generators/image_classifier.py +2 -2
- nucliadb/train/generators/paragraph_classifier.py +2 -2
- nucliadb/train/generators/paragraph_streaming.py +2 -2
- nucliadb/train/generators/question_answer_streaming.py +2 -2
- nucliadb/train/generators/sentence_classifier.py +4 -10
- nucliadb/train/generators/token_classifier.py +3 -2
- nucliadb/train/generators/utils.py +6 -5
- nucliadb/train/nodes.py +3 -3
- nucliadb/train/resource.py +6 -8
- nucliadb/train/settings.py +3 -4
- nucliadb/train/types.py +11 -11
- nucliadb/train/upload.py +3 -2
- nucliadb/train/uploader.py +1 -2
- nucliadb/train/utils.py +1 -2
- nucliadb/writer/api/v1/export_import.py +4 -1
- nucliadb/writer/api/v1/field.py +15 -14
- nucliadb/writer/api/v1/knowledgebox.py +18 -56
- nucliadb/writer/api/v1/learning_config.py +5 -4
- nucliadb/writer/api/v1/resource.py +9 -20
- nucliadb/writer/api/v1/services.py +10 -132
- nucliadb/writer/api/v1/upload.py +73 -72
- nucliadb/writer/app.py +8 -2
- nucliadb/writer/resource/basic.py +12 -15
- nucliadb/writer/resource/field.py +43 -5
- nucliadb/writer/resource/origin.py +7 -0
- nucliadb/writer/settings.py +2 -3
- nucliadb/writer/tus/__init__.py +2 -3
- nucliadb/writer/tus/azure.py +5 -7
- nucliadb/writer/tus/dm.py +3 -3
- nucliadb/writer/tus/exceptions.py +3 -4
- nucliadb/writer/tus/gcs.py +15 -22
- nucliadb/writer/tus/s3.py +2 -3
- nucliadb/writer/tus/storage.py +3 -3
- {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/METADATA +10 -11
- nucliadb-6.10.0.post5705.dist-info/RECORD +410 -0
- nucliadb/common/datamanagers/entities.py +0 -139
- nucliadb/common/external_index_providers/pinecone.py +0 -894
- nucliadb/ingest/orm/processor/pgcatalog.py +0 -129
- nucliadb/search/search/hydrator.py +0 -197
- nucliadb-6.7.2.post4874.dist-info/RECORD +0 -383
- {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/WHEEL +0 -0
- {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/top_level.txt +0 -0
|
@@ -17,34 +17,36 @@
|
|
|
17
17
|
# You should have received a copy of the GNU Affero General Public License
|
|
18
18
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
19
19
|
#
|
|
20
|
-
import asyncio
|
|
21
20
|
import copy
|
|
22
|
-
from collections import
|
|
21
|
+
from collections.abc import Sequence
|
|
23
22
|
from dataclasses import dataclass
|
|
24
|
-
from typing import
|
|
23
|
+
from typing import cast
|
|
25
24
|
|
|
26
25
|
import yaml
|
|
27
26
|
from pydantic import BaseModel
|
|
28
27
|
|
|
29
|
-
|
|
30
|
-
from nucliadb.common.
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
from nucliadb.ingest.fields.file import File
|
|
35
|
-
from nucliadb.ingest.orm.knowledgebox import KnowledgeBox as KnowledgeBoxORM
|
|
36
|
-
from nucliadb.search import logger
|
|
37
|
-
from nucliadb.search.search import cache
|
|
38
|
-
from nucliadb.search.search.chat.images import (
|
|
39
|
-
get_file_thumbnail_image,
|
|
40
|
-
get_page_image,
|
|
41
|
-
get_paragraph_image,
|
|
28
|
+
import nucliadb_models
|
|
29
|
+
from nucliadb.common.ids import (
|
|
30
|
+
FIELD_TYPE_STR_TO_NAME,
|
|
31
|
+
FieldId,
|
|
32
|
+
ParagraphId,
|
|
42
33
|
)
|
|
43
|
-
from nucliadb.search
|
|
34
|
+
from nucliadb.search import logger
|
|
35
|
+
from nucliadb.search.search.chat import rpc
|
|
44
36
|
from nucliadb.search.search.metrics import Metrics
|
|
45
|
-
from
|
|
37
|
+
from nucliadb_models.augment import (
|
|
38
|
+
AugmentedConversationField,
|
|
39
|
+
AugmentedField,
|
|
40
|
+
AugmentedFileField,
|
|
41
|
+
AugmentFields,
|
|
42
|
+
AugmentParagraph,
|
|
43
|
+
AugmentParagraphs,
|
|
44
|
+
AugmentRequest,
|
|
45
|
+
AugmentResourceFields,
|
|
46
|
+
AugmentResources,
|
|
47
|
+
)
|
|
48
|
+
from nucliadb_models.common import FieldTypeName
|
|
46
49
|
from nucliadb_models.labels import translate_alias_to_system_label
|
|
47
|
-
from nucliadb_models.metadata import Extra, Origin
|
|
48
50
|
from nucliadb_models.search import (
|
|
49
51
|
SCORE_TYPE,
|
|
50
52
|
AugmentedContext,
|
|
@@ -71,24 +73,9 @@ from nucliadb_models.search import (
|
|
|
71
73
|
TextBlockAugmentationType,
|
|
72
74
|
TextPosition,
|
|
73
75
|
)
|
|
74
|
-
from nucliadb_protos import
|
|
75
|
-
from nucliadb_protos.resources_pb2 import ExtractedText, FieldComputedMetadata
|
|
76
|
-
from nucliadb_utils.asyncio_utils import run_concurrently
|
|
77
|
-
from nucliadb_utils.utilities import get_storage
|
|
78
|
-
|
|
79
|
-
MAX_RESOURCE_TASKS = 5
|
|
80
|
-
MAX_RESOURCE_FIELD_TASKS = 4
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
# Number of messages to pull after a match in a message
|
|
84
|
-
# The hope here is it will be enough to get the answer to the question.
|
|
85
|
-
CONVERSATION_MESSAGE_CONTEXT_EXPANSION = 15
|
|
86
|
-
|
|
87
|
-
TextBlockId = Union[ParagraphId, FieldId]
|
|
76
|
+
from nucliadb_protos.resources_pb2 import FieldComputedMetadata
|
|
88
77
|
|
|
89
|
-
|
|
90
|
-
class ParagraphIdNotFoundInExtractedMetadata(Exception):
|
|
91
|
-
pass
|
|
78
|
+
TextBlockId = ParagraphId | FieldId
|
|
92
79
|
|
|
93
80
|
|
|
94
81
|
class CappedPromptContext:
|
|
@@ -97,7 +84,7 @@ class CappedPromptContext:
|
|
|
97
84
|
and automatically trim data that exceeds the limit when it's being set on the dictionary.
|
|
98
85
|
"""
|
|
99
86
|
|
|
100
|
-
def __init__(self, max_size:
|
|
87
|
+
def __init__(self, max_size: int | None):
|
|
101
88
|
self.output: PromptContext = {}
|
|
102
89
|
self.images: PromptContextImages = {}
|
|
103
90
|
self.max_size = max_size
|
|
@@ -158,79 +145,6 @@ class CappedPromptContext:
|
|
|
158
145
|
return self.output
|
|
159
146
|
|
|
160
147
|
|
|
161
|
-
async def get_next_conversation_messages(
|
|
162
|
-
*,
|
|
163
|
-
field_obj: Conversation,
|
|
164
|
-
page: int,
|
|
165
|
-
start_idx: int,
|
|
166
|
-
num_messages: int,
|
|
167
|
-
message_type: Optional[resources_pb2.Message.MessageType.ValueType] = None,
|
|
168
|
-
msg_to: Optional[str] = None,
|
|
169
|
-
) -> List[resources_pb2.Message]:
|
|
170
|
-
output = []
|
|
171
|
-
cmetadata = await field_obj.get_metadata()
|
|
172
|
-
for current_page in range(page, cmetadata.pages + 1):
|
|
173
|
-
conv = await field_obj.db_get_value(current_page)
|
|
174
|
-
for message in conv.messages[start_idx:]:
|
|
175
|
-
if message_type is not None and message.type != message_type: # pragma: no cover
|
|
176
|
-
continue
|
|
177
|
-
if msg_to is not None and msg_to not in message.to: # pragma: no cover
|
|
178
|
-
continue
|
|
179
|
-
output.append(message)
|
|
180
|
-
if len(output) >= num_messages:
|
|
181
|
-
return output
|
|
182
|
-
start_idx = 0
|
|
183
|
-
|
|
184
|
-
return output
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
async def find_conversation_message(
|
|
188
|
-
field_obj: Conversation, mident: str
|
|
189
|
-
) -> tuple[Optional[resources_pb2.Message], int, int]:
|
|
190
|
-
cmetadata = await field_obj.get_metadata()
|
|
191
|
-
for page in range(1, cmetadata.pages + 1):
|
|
192
|
-
conv = await field_obj.db_get_value(page)
|
|
193
|
-
for idx, message in enumerate(conv.messages):
|
|
194
|
-
if message.ident == mident:
|
|
195
|
-
return message, page, idx
|
|
196
|
-
return None, -1, -1
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
async def get_expanded_conversation_messages(
|
|
200
|
-
*,
|
|
201
|
-
kb: KnowledgeBoxORM,
|
|
202
|
-
rid: str,
|
|
203
|
-
field_id: str,
|
|
204
|
-
mident: str,
|
|
205
|
-
max_messages: int = CONVERSATION_MESSAGE_CONTEXT_EXPANSION,
|
|
206
|
-
) -> list[resources_pb2.Message]:
|
|
207
|
-
resource = await kb.get(rid)
|
|
208
|
-
if resource is None: # pragma: no cover
|
|
209
|
-
return []
|
|
210
|
-
field_obj: Conversation = await resource.get_field(field_id, FIELD_TYPE_STR_TO_PB["c"], load=True) # type: ignore
|
|
211
|
-
found_message, found_page, found_idx = await find_conversation_message(
|
|
212
|
-
field_obj=field_obj, mident=mident
|
|
213
|
-
)
|
|
214
|
-
if found_message is None: # pragma: no cover
|
|
215
|
-
return []
|
|
216
|
-
elif found_message.type == resources_pb2.Message.MessageType.QUESTION:
|
|
217
|
-
# only try to get answer if it was a question
|
|
218
|
-
return await get_next_conversation_messages(
|
|
219
|
-
field_obj=field_obj,
|
|
220
|
-
page=found_page,
|
|
221
|
-
start_idx=found_idx + 1,
|
|
222
|
-
num_messages=1,
|
|
223
|
-
message_type=resources_pb2.Message.MessageType.ANSWER,
|
|
224
|
-
)
|
|
225
|
-
else:
|
|
226
|
-
return await get_next_conversation_messages(
|
|
227
|
-
field_obj=field_obj,
|
|
228
|
-
page=found_page,
|
|
229
|
-
start_idx=found_idx + 1,
|
|
230
|
-
num_messages=max_messages,
|
|
231
|
-
)
|
|
232
|
-
|
|
233
|
-
|
|
234
148
|
async def default_prompt_context(
|
|
235
149
|
context: CappedPromptContext,
|
|
236
150
|
kbid: str,
|
|
@@ -245,35 +159,59 @@ async def default_prompt_context(
|
|
|
245
159
|
- User context is inserted first, in order of appearance.
|
|
246
160
|
- Using an dict prevents from duplicates pulled in through conversation expansion.
|
|
247
161
|
"""
|
|
248
|
-
# Sort retrieved paragraphs by decreasing order (most relevant first)
|
|
249
|
-
async with get_driver().ro_transaction() as txn:
|
|
250
|
-
storage = await get_storage()
|
|
251
|
-
kb = KnowledgeBoxORM(txn, storage, kbid)
|
|
252
|
-
for paragraph in ordered_paragraphs:
|
|
253
|
-
context[paragraph.id] = _clean_paragraph_text(paragraph)
|
|
254
162
|
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
163
|
+
conversations = []
|
|
164
|
+
|
|
165
|
+
for paragraph in ordered_paragraphs:
|
|
166
|
+
context[paragraph.id] = _clean_paragraph_text(paragraph)
|
|
167
|
+
|
|
168
|
+
# If the paragraph is a conversation and it matches semantically, we
|
|
169
|
+
# assume we have matched with the question, therefore try to include the
|
|
170
|
+
# answer to the context by pulling the next few messages of the
|
|
171
|
+
# conversation field
|
|
172
|
+
rid, field_type, field_id, mident = paragraph.id.split("/")[:4]
|
|
173
|
+
# FIXME: a semantic paragraph can have reranker score. Once we
|
|
174
|
+
# refactor and have access to the score history, we can fix this
|
|
175
|
+
if field_type == "c" and paragraph.score_type in (
|
|
176
|
+
SCORE_TYPE.VECTOR,
|
|
177
|
+
SCORE_TYPE.BOTH,
|
|
178
|
+
):
|
|
179
|
+
conversations.append(f"{rid}/{field_type}/{field_id}/{mident}")
|
|
180
|
+
|
|
181
|
+
augment = AugmentRequest(
|
|
182
|
+
fields=[
|
|
183
|
+
AugmentFields(
|
|
184
|
+
given=[id for id in conversations],
|
|
185
|
+
conversation_answer_or_messages_after=True,
|
|
186
|
+
),
|
|
187
|
+
]
|
|
188
|
+
)
|
|
189
|
+
augmented = await rpc.augment(kbid, augment)
|
|
190
|
+
|
|
191
|
+
for id in conversations:
|
|
192
|
+
conversation_id = FieldId.from_string(id)
|
|
193
|
+
|
|
194
|
+
augmented_field = augmented.fields.get(conversation_id.full_without_subfield())
|
|
195
|
+
if augmented_field is None or not isinstance(augmented_field, AugmentedConversationField):
|
|
196
|
+
continue
|
|
197
|
+
|
|
198
|
+
for message in augmented_field.messages or []:
|
|
199
|
+
if message.text is None:
|
|
200
|
+
continue
|
|
201
|
+
|
|
202
|
+
message_id = copy.copy(conversation_id)
|
|
203
|
+
message_id.subfield_id = message.ident
|
|
204
|
+
pid = ParagraphId(
|
|
205
|
+
field_id=message_id, paragraph_start=0, paragraph_end=len(message.text)
|
|
206
|
+
).full()
|
|
207
|
+
context[pid] = message.text
|
|
270
208
|
|
|
271
209
|
|
|
272
210
|
async def full_resource_prompt_context(
|
|
273
211
|
context: CappedPromptContext,
|
|
274
212
|
kbid: str,
|
|
275
213
|
ordered_paragraphs: list[FindParagraph],
|
|
276
|
-
|
|
214
|
+
rid: str | None,
|
|
277
215
|
strategy: FullResourceStrategy,
|
|
278
216
|
metrics: Metrics,
|
|
279
217
|
augmented_context: AugmentedContext,
|
|
@@ -288,16 +226,16 @@ async def full_resource_prompt_context(
|
|
|
288
226
|
ordered_paragraphs: The results of the retrieval (find) operation.
|
|
289
227
|
resource: The resource to be included in the context. This is used only when chatting with a specific resource with no retrieval.
|
|
290
228
|
strategy: strategy instance containing, for example, the number of full resources to include in the context.
|
|
291
|
-
"""
|
|
292
|
-
if
|
|
229
|
+
"""
|
|
230
|
+
if rid is not None:
|
|
293
231
|
# The user has specified a resource to be included in the context.
|
|
294
|
-
ordered_resources = [
|
|
232
|
+
ordered_resources = [rid]
|
|
295
233
|
else:
|
|
296
234
|
# Collect the list of resources in the results (in order of relevance).
|
|
297
235
|
ordered_resources = []
|
|
298
236
|
for paragraph in ordered_paragraphs:
|
|
299
|
-
|
|
300
|
-
if
|
|
237
|
+
rid = parse_text_block_id(paragraph.id).rid
|
|
238
|
+
if rid not in ordered_resources:
|
|
301
239
|
skip = False
|
|
302
240
|
if strategy.apply_to is not None:
|
|
303
241
|
# decide whether the resource should be extended or not
|
|
@@ -307,35 +245,62 @@ async def full_resource_prompt_context(
|
|
|
307
245
|
)
|
|
308
246
|
|
|
309
247
|
if not skip:
|
|
310
|
-
ordered_resources.append(
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
|
|
248
|
+
ordered_resources.append(rid)
|
|
249
|
+
# skip when we have enough resource ids
|
|
250
|
+
if strategy.count is not None and len(ordered_resources) > strategy.count:
|
|
251
|
+
break
|
|
252
|
+
|
|
253
|
+
ordered_resources = ordered_resources[: strategy.count]
|
|
254
|
+
|
|
255
|
+
# For each resource, collect the extracted text from all its fields and
|
|
256
|
+
# include the title and summary as well
|
|
257
|
+
augmented = await rpc.augment(
|
|
258
|
+
kbid,
|
|
259
|
+
AugmentRequest(
|
|
260
|
+
resources=[
|
|
261
|
+
AugmentResources(
|
|
262
|
+
given=ordered_resources,
|
|
263
|
+
title=True,
|
|
264
|
+
summary=True,
|
|
265
|
+
fields=AugmentResourceFields(
|
|
266
|
+
text=True,
|
|
267
|
+
filters=[],
|
|
268
|
+
),
|
|
269
|
+
)
|
|
270
|
+
]
|
|
271
|
+
),
|
|
319
272
|
)
|
|
273
|
+
|
|
274
|
+
extracted_texts = {}
|
|
275
|
+
for rid, resource in augmented.resources.items():
|
|
276
|
+
if resource.title is not None:
|
|
277
|
+
field_id = FieldId(rid=rid, type="a", key="title").full()
|
|
278
|
+
extracted_texts[field_id] = resource.title
|
|
279
|
+
if resource.summary is not None:
|
|
280
|
+
field_id = FieldId(rid=rid, type="a", key="summary").full()
|
|
281
|
+
extracted_texts[field_id] = resource.summary
|
|
282
|
+
|
|
283
|
+
for field_id, field in augmented.fields.items():
|
|
284
|
+
field = cast(AugmentedField, field)
|
|
285
|
+
if field.text is not None:
|
|
286
|
+
extracted_texts[field_id] = field.text
|
|
287
|
+
|
|
320
288
|
added_fields = set()
|
|
321
|
-
for
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
for
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
text=extracted_text,
|
|
335
|
-
augmentation_type=TextBlockAugmentationType.FULL_RESOURCE,
|
|
336
|
-
)
|
|
289
|
+
for field_id, extracted_text in extracted_texts.items():
|
|
290
|
+
# First off, remove the text block ids from paragraphs that belong to
|
|
291
|
+
# the same field, as otherwise the context will be duplicated.
|
|
292
|
+
for tb_id in context.text_block_ids():
|
|
293
|
+
if tb_id.startswith(field_id):
|
|
294
|
+
del context[tb_id]
|
|
295
|
+
# Add the extracted text of each field to the context.
|
|
296
|
+
context[field_id] = extracted_text
|
|
297
|
+
augmented_context.fields[field_id] = AugmentedTextBlock(
|
|
298
|
+
id=field_id,
|
|
299
|
+
text=extracted_text,
|
|
300
|
+
augmentation_type=TextBlockAugmentationType.FULL_RESOURCE,
|
|
301
|
+
)
|
|
337
302
|
|
|
338
|
-
|
|
303
|
+
added_fields.add(field_id)
|
|
339
304
|
|
|
340
305
|
metrics.set("full_resource_ops", len(added_fields))
|
|
341
306
|
|
|
@@ -353,213 +318,167 @@ async def extend_prompt_context_with_metadata(
|
|
|
353
318
|
metrics: Metrics,
|
|
354
319
|
augmented_context: AugmentedContext,
|
|
355
320
|
) -> None:
|
|
321
|
+
rids: list[str] = []
|
|
322
|
+
field_ids: list[str] = []
|
|
356
323
|
text_block_ids: list[TextBlockId] = []
|
|
357
324
|
for text_block_id in context.text_block_ids():
|
|
358
325
|
try:
|
|
359
|
-
|
|
326
|
+
tb_id = parse_text_block_id(text_block_id)
|
|
360
327
|
except ValueError: # pragma: no cover
|
|
361
328
|
# Some text block ids are not paragraphs nor fields, so they are skipped
|
|
362
329
|
# (e.g. USER_CONTEXT_0, when the user provides extra context)
|
|
363
330
|
continue
|
|
331
|
+
|
|
332
|
+
field_id = tb_id if isinstance(tb_id, FieldId) else tb_id.field_id
|
|
333
|
+
|
|
334
|
+
text_block_ids.append(tb_id)
|
|
335
|
+
field_ids.append(field_id.full())
|
|
336
|
+
rids.append(tb_id.rid)
|
|
337
|
+
|
|
364
338
|
if len(text_block_ids) == 0: # pragma: no cover
|
|
365
339
|
return
|
|
366
340
|
|
|
341
|
+
resource_origin = False
|
|
342
|
+
resource_extra = False
|
|
343
|
+
classification_labels = False
|
|
344
|
+
field_entities = False
|
|
345
|
+
|
|
367
346
|
ops = 0
|
|
368
347
|
if MetadataExtensionType.ORIGIN in strategy.types:
|
|
369
348
|
ops += 1
|
|
370
|
-
|
|
371
|
-
context, kbid, text_block_ids, augmented_context
|
|
372
|
-
)
|
|
349
|
+
resource_origin = True
|
|
373
350
|
|
|
374
351
|
if MetadataExtensionType.CLASSIFICATION_LABELS in strategy.types:
|
|
375
352
|
ops += 1
|
|
376
|
-
|
|
377
|
-
context, kbid, text_block_ids, augmented_context
|
|
378
|
-
)
|
|
353
|
+
classification_labels = True
|
|
379
354
|
|
|
380
355
|
if MetadataExtensionType.NERS in strategy.types:
|
|
381
356
|
ops += 1
|
|
382
|
-
|
|
357
|
+
field_entities = True
|
|
383
358
|
|
|
384
359
|
if MetadataExtensionType.EXTRA_METADATA in strategy.types:
|
|
385
360
|
ops += 1
|
|
386
|
-
|
|
361
|
+
resource_extra = True
|
|
387
362
|
|
|
388
363
|
metrics.set("metadata_extension_ops", ops * len(text_block_ids))
|
|
389
364
|
|
|
365
|
+
augment_req = AugmentRequest()
|
|
366
|
+
if resource_origin or resource_extra or classification_labels:
|
|
367
|
+
augment_req.resources = [
|
|
368
|
+
AugmentResources(
|
|
369
|
+
given=rids,
|
|
370
|
+
origin=resource_origin,
|
|
371
|
+
extra=resource_extra,
|
|
372
|
+
classification_labels=classification_labels,
|
|
373
|
+
)
|
|
374
|
+
]
|
|
375
|
+
if classification_labels or field_entities:
|
|
376
|
+
augment_req.fields = [
|
|
377
|
+
AugmentFields(
|
|
378
|
+
given=field_ids,
|
|
379
|
+
classification_labels=classification_labels,
|
|
380
|
+
entities=field_entities,
|
|
381
|
+
)
|
|
382
|
+
]
|
|
390
383
|
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
return ParagraphId.from_string(text_block_id)
|
|
395
|
-
except ValueError:
|
|
396
|
-
# When we're doing `full_resource` or `hierarchy` strategies,the text block id
|
|
397
|
-
# is a field id
|
|
398
|
-
return FieldId.from_string(text_block_id)
|
|
384
|
+
if augment_req.resources is None and augment_req.fields is None:
|
|
385
|
+
# nothing to augment
|
|
386
|
+
return
|
|
399
387
|
|
|
388
|
+
augmented = await rpc.augment(kbid, augment_req)
|
|
400
389
|
|
|
401
|
-
async def extend_prompt_context_with_origin_metadata(
|
|
402
|
-
context: CappedPromptContext,
|
|
403
|
-
kbid,
|
|
404
|
-
text_block_ids: list[TextBlockId],
|
|
405
|
-
augmented_context: AugmentedContext,
|
|
406
|
-
):
|
|
407
|
-
async def _get_origin(kbid: str, rid: str) -> tuple[str, Optional[Origin]]:
|
|
408
|
-
origin = None
|
|
409
|
-
resource = await cache.get_resource(kbid, rid)
|
|
410
|
-
if resource is not None:
|
|
411
|
-
pb_origin = await resource.get_origin()
|
|
412
|
-
if pb_origin is not None:
|
|
413
|
-
origin = from_proto.origin(pb_origin)
|
|
414
|
-
return rid, origin
|
|
415
|
-
|
|
416
|
-
rids = {tb_id.rid for tb_id in text_block_ids}
|
|
417
|
-
origins = await run_concurrently([_get_origin(kbid, rid) for rid in rids])
|
|
418
|
-
rid_to_origin = {rid: origin for rid, origin in origins if origin is not None}
|
|
419
390
|
for tb_id in text_block_ids:
|
|
420
|
-
|
|
421
|
-
if origin is not None and tb_id.full() in context:
|
|
422
|
-
text = context.output.pop(tb_id.full())
|
|
423
|
-
extended_text = text + f"\n\nDOCUMENT METADATA AT ORIGIN:\n{to_yaml(origin)}"
|
|
424
|
-
context[tb_id.full()] = extended_text
|
|
425
|
-
augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
|
|
426
|
-
id=tb_id.full(),
|
|
427
|
-
text=extended_text,
|
|
428
|
-
parent=tb_id.full(),
|
|
429
|
-
augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
|
|
430
|
-
)
|
|
391
|
+
field_id = tb_id if isinstance(tb_id, FieldId) else tb_id.field_id
|
|
431
392
|
|
|
393
|
+
resource = augmented.resources.get(tb_id.rid)
|
|
394
|
+
field = augmented.fields.get(field_id.full())
|
|
432
395
|
|
|
433
|
-
async def extend_prompt_context_with_classification_labels(
|
|
434
|
-
context: CappedPromptContext,
|
|
435
|
-
kbid: str,
|
|
436
|
-
text_block_ids: list[TextBlockId],
|
|
437
|
-
augmented_context: AugmentedContext,
|
|
438
|
-
):
|
|
439
|
-
async def _get_labels(kbid: str, _id: TextBlockId) -> tuple[TextBlockId, list[tuple[str, str]]]:
|
|
440
|
-
fid = _id if isinstance(_id, FieldId) else _id.field_id
|
|
441
|
-
labels = set()
|
|
442
|
-
resource = await cache.get_resource(kbid, fid.rid)
|
|
443
396
|
if resource is not None:
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
continue
|
|
455
|
-
labels.add((classif.labelset, classif.label))
|
|
456
|
-
return _id, list(labels)
|
|
457
|
-
|
|
458
|
-
classif_labels = await run_concurrently([_get_labels(kbid, tb_id) for tb_id in text_block_ids])
|
|
459
|
-
tb_id_to_labels = {tb_id: labels for tb_id, labels in classif_labels if len(labels) > 0}
|
|
460
|
-
for tb_id in text_block_ids:
|
|
461
|
-
labels = tb_id_to_labels.get(tb_id)
|
|
462
|
-
if labels is not None and tb_id.full() in context:
|
|
463
|
-
text = context.output.pop(tb_id.full())
|
|
464
|
-
|
|
465
|
-
labels_text = "DOCUMENT CLASSIFICATION LABELS:"
|
|
466
|
-
for labelset, label in labels:
|
|
467
|
-
labels_text += f"\n - {label} ({labelset})"
|
|
468
|
-
extended_text = text + "\n\n" + labels_text
|
|
469
|
-
|
|
470
|
-
context[tb_id.full()] = extended_text
|
|
471
|
-
augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
|
|
472
|
-
id=tb_id.full(),
|
|
473
|
-
text=extended_text,
|
|
474
|
-
parent=tb_id.full(),
|
|
475
|
-
augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
|
|
476
|
-
)
|
|
397
|
+
if resource.origin is not None:
|
|
398
|
+
text = context.output.pop(tb_id.full())
|
|
399
|
+
extended_text = text + f"\n\nDOCUMENT METADATA AT ORIGIN:\n{to_yaml(resource.origin)}"
|
|
400
|
+
context[tb_id.full()] = extended_text
|
|
401
|
+
augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
|
|
402
|
+
id=tb_id.full(),
|
|
403
|
+
text=extended_text,
|
|
404
|
+
parent=tb_id.full(),
|
|
405
|
+
augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
|
|
406
|
+
)
|
|
477
407
|
|
|
408
|
+
if resource.extra is not None:
|
|
409
|
+
text = context.output.pop(tb_id.full())
|
|
410
|
+
extended_text = text + f"\n\nDOCUMENT EXTRA METADATA:\n{to_yaml(resource.extra)}"
|
|
411
|
+
context[tb_id.full()] = extended_text
|
|
412
|
+
augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
|
|
413
|
+
id=tb_id.full(),
|
|
414
|
+
text=extended_text,
|
|
415
|
+
parent=tb_id.full(),
|
|
416
|
+
augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
|
|
417
|
+
)
|
|
478
418
|
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
nerss = await run_concurrently([_get_ners(kbid, tb_id) for tb_id in text_block_ids])
|
|
507
|
-
tb_id_to_ners = {tb_id: ners for tb_id, ners in nerss if len(ners) > 0}
|
|
508
|
-
for tb_id in text_block_ids:
|
|
509
|
-
ners = tb_id_to_ners.get(tb_id)
|
|
510
|
-
if ners is not None and tb_id.full() in context:
|
|
511
|
-
text = context.output.pop(tb_id.full())
|
|
512
|
-
|
|
513
|
-
ners_text = "DOCUMENT NAMED ENTITIES (NERs):"
|
|
514
|
-
for family, tokens in ners.items():
|
|
515
|
-
ners_text += f"\n - {family}:"
|
|
516
|
-
for token in sorted(list(tokens)):
|
|
517
|
-
ners_text += f"\n - {token}"
|
|
518
|
-
|
|
519
|
-
extended_text = text + "\n\n" + ners_text
|
|
520
|
-
|
|
521
|
-
context[tb_id.full()] = extended_text
|
|
522
|
-
augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
|
|
523
|
-
id=tb_id.full(),
|
|
524
|
-
text=extended_text,
|
|
525
|
-
parent=tb_id.full(),
|
|
526
|
-
augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
|
|
527
|
-
)
|
|
419
|
+
if tb_id.full() in context:
|
|
420
|
+
if (resource is not None and resource.classification_labels) or (
|
|
421
|
+
field is not None and field.classification_labels
|
|
422
|
+
):
|
|
423
|
+
text = context.output.pop(tb_id.full())
|
|
424
|
+
|
|
425
|
+
labels_text = "DOCUMENT CLASSIFICATION LABELS:"
|
|
426
|
+
if resource is not None and resource.classification_labels:
|
|
427
|
+
for labelset, labels in resource.classification_labels.items():
|
|
428
|
+
for label in labels:
|
|
429
|
+
labels_text += f"\n - {label} ({labelset})"
|
|
430
|
+
|
|
431
|
+
if field is not None and field.classification_labels:
|
|
432
|
+
for labelset, labels in field.classification_labels.items():
|
|
433
|
+
for label in labels:
|
|
434
|
+
labels_text += f"\n - {label} ({labelset})"
|
|
435
|
+
|
|
436
|
+
extended_text = text + "\n\n" + labels_text
|
|
437
|
+
|
|
438
|
+
context[tb_id.full()] = extended_text
|
|
439
|
+
augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
|
|
440
|
+
id=tb_id.full(),
|
|
441
|
+
text=extended_text,
|
|
442
|
+
parent=tb_id.full(),
|
|
443
|
+
augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
|
|
444
|
+
)
|
|
528
445
|
|
|
446
|
+
if field is not None and field.entities:
|
|
447
|
+
ners = field.entities
|
|
529
448
|
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
):
|
|
536
|
-
|
|
537
|
-
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
parent=tb_id.full(),
|
|
558
|
-
augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
|
|
559
|
-
)
|
|
449
|
+
text = context.output.pop(tb_id.full())
|
|
450
|
+
|
|
451
|
+
ners_text = "DOCUMENT NAMED ENTITIES (NERs):"
|
|
452
|
+
for family, tokens in ners.items():
|
|
453
|
+
ners_text += f"\n - {family}:"
|
|
454
|
+
for token in sorted(list(tokens)):
|
|
455
|
+
ners_text += f"\n - {token}"
|
|
456
|
+
|
|
457
|
+
extended_text = text + "\n\n" + ners_text
|
|
458
|
+
|
|
459
|
+
context[tb_id.full()] = extended_text
|
|
460
|
+
augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
|
|
461
|
+
id=tb_id.full(),
|
|
462
|
+
text=extended_text,
|
|
463
|
+
parent=tb_id.full(),
|
|
464
|
+
augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
|
|
465
|
+
)
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
def parse_text_block_id(text_block_id: str) -> TextBlockId:
|
|
469
|
+
try:
|
|
470
|
+
# Typically, the text block id is a paragraph id
|
|
471
|
+
return ParagraphId.from_string(text_block_id)
|
|
472
|
+
except ValueError:
|
|
473
|
+
# When we're doing `full_resource` or `hierarchy` strategies,the text block id
|
|
474
|
+
# is a field id
|
|
475
|
+
return FieldId.from_string(text_block_id)
|
|
560
476
|
|
|
561
477
|
|
|
562
478
|
def to_yaml(obj: BaseModel) -> str:
|
|
479
|
+
# FIXME: this dumps enums REALLY poorly, e.g.,
|
|
480
|
+
# `!!python/object/apply:nucliadb_models.metadata.Source\n- WEB` for
|
|
481
|
+
# Source.WEB instead of `WEB`
|
|
563
482
|
return yaml.dump(
|
|
564
483
|
obj.model_dump(exclude_none=True, exclude_defaults=True, exclude_unset=True),
|
|
565
484
|
default_flow_style=False,
|
|
@@ -589,37 +508,74 @@ async def field_extension_prompt_context(
|
|
|
589
508
|
if resource_uuid not in ordered_resources:
|
|
590
509
|
ordered_resources.append(resource_uuid)
|
|
591
510
|
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
511
|
+
resource_title = False
|
|
512
|
+
resource_summary = False
|
|
513
|
+
filters: list[nucliadb_models.filters.Field | nucliadb_models.filters.Generated] = []
|
|
514
|
+
# this strategy exposes a way to access resource title and summary using a
|
|
515
|
+
# field id. However, as they are resource properties, we must request it as
|
|
516
|
+
# that
|
|
517
|
+
for name in strategy.fields:
|
|
518
|
+
if name == "a/title":
|
|
519
|
+
resource_title = True
|
|
520
|
+
elif name == "a/summary":
|
|
521
|
+
resource_summary = True
|
|
522
|
+
else:
|
|
523
|
+
# model already enforces type/name format
|
|
524
|
+
field_type, field_name = name.split("/")
|
|
525
|
+
filters.append(
|
|
526
|
+
nucliadb_models.filters.Field(
|
|
527
|
+
type=FIELD_TYPE_STR_TO_NAME[field_type], name=field_name or None
|
|
528
|
+
)
|
|
529
|
+
)
|
|
603
530
|
|
|
604
|
-
|
|
605
|
-
|
|
531
|
+
for da_prefix in strategy.data_augmentation_field_prefixes:
|
|
532
|
+
filters.append(nucliadb_models.filters.Generated(by="data-augmentation", da_task=da_prefix))
|
|
533
|
+
|
|
534
|
+
augmented = await rpc.augment(
|
|
535
|
+
kbid,
|
|
536
|
+
AugmentRequest(
|
|
537
|
+
resources=[
|
|
538
|
+
AugmentResources(
|
|
539
|
+
given=ordered_resources,
|
|
540
|
+
title=resource_title,
|
|
541
|
+
summary=resource_summary,
|
|
542
|
+
fields=AugmentResourceFields(
|
|
543
|
+
text=True,
|
|
544
|
+
filters=filters,
|
|
545
|
+
),
|
|
546
|
+
)
|
|
547
|
+
]
|
|
548
|
+
),
|
|
549
|
+
)
|
|
550
|
+
|
|
551
|
+
# REVIEW(decoupled-ask): we don't have the field count anymore, is this good enough?
|
|
552
|
+
metrics.set("field_extension_ops", len(ordered_resources))
|
|
606
553
|
|
|
607
|
-
|
|
554
|
+
extracted_texts = {}
|
|
555
|
+
# now we need to expose title and summary as fields again, so it gets
|
|
556
|
+
# consistent with the view we are providing in the API
|
|
557
|
+
for rid, augmented_resource in augmented.resources.items():
|
|
558
|
+
if augmented_resource.title:
|
|
559
|
+
extracted_texts[f"{rid}/a/title"] = augmented_resource.title
|
|
560
|
+
if augmented_resource.summary:
|
|
561
|
+
extracted_texts[f"{rid}/a/summary"] = augmented_resource.summary
|
|
608
562
|
|
|
609
|
-
for
|
|
610
|
-
if
|
|
563
|
+
for fid, augmented_field in augmented.fields.items():
|
|
564
|
+
if augmented_field is None or augmented_field.text is None: # pragma: no cover
|
|
611
565
|
continue
|
|
612
|
-
|
|
566
|
+
extracted_texts[fid] = augmented_field.text
|
|
567
|
+
|
|
568
|
+
for fid, extracted_text in extracted_texts.items():
|
|
613
569
|
# First off, remove the text block ids from paragraphs that belong to
|
|
614
570
|
# the same field, as otherwise the context will be duplicated.
|
|
615
571
|
for tb_id in context.text_block_ids():
|
|
616
|
-
if tb_id.startswith(
|
|
572
|
+
if tb_id.startswith(fid):
|
|
617
573
|
del context[tb_id]
|
|
618
574
|
# Add the extracted text of each field to the beginning of the context.
|
|
619
|
-
if
|
|
620
|
-
context[
|
|
621
|
-
augmented_context.fields[
|
|
622
|
-
id=
|
|
575
|
+
if fid not in context:
|
|
576
|
+
context[fid] = extracted_text
|
|
577
|
+
augmented_context.fields[fid] = AugmentedTextBlock(
|
|
578
|
+
id=fid,
|
|
623
579
|
text=extracted_text,
|
|
624
580
|
augmentation_type=TextBlockAugmentationType.FIELD_EXTENSION,
|
|
625
581
|
)
|
|
@@ -630,13 +586,6 @@ async def field_extension_prompt_context(
|
|
|
630
586
|
context[paragraph.id] = _clean_paragraph_text(paragraph)
|
|
631
587
|
|
|
632
588
|
|
|
633
|
-
async def get_orm_field(kbid: str, field_id: FieldId) -> Optional[Field]:
|
|
634
|
-
resource = await cache.get_resource(kbid, field_id.rid)
|
|
635
|
-
if resource is None: # pragma: no cover
|
|
636
|
-
return None
|
|
637
|
-
return await resource.get_field(key=field_id.key, type=field_id.pb_type, load=False)
|
|
638
|
-
|
|
639
|
-
|
|
640
589
|
async def neighbouring_paragraphs_prompt_context(
|
|
641
590
|
context: CappedPromptContext,
|
|
642
591
|
kbid: str,
|
|
@@ -652,83 +601,52 @@ async def neighbouring_paragraphs_prompt_context(
|
|
|
652
601
|
retrieved_paragraphs_ids = [
|
|
653
602
|
ParagraphId.from_string(text_block.id) for text_block in ordered_text_blocks
|
|
654
603
|
]
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
}
|
|
670
|
-
extracted_texts: dict[FieldId, ExtractedText] = {
|
|
671
|
-
fid: et for fid, et in zip(unique_field_ids, await asyncio.gather(*et_ops)) if et is not None
|
|
672
|
-
}
|
|
673
|
-
|
|
674
|
-
def _get_paragraph_text(extracted_text: ExtractedText, pid: ParagraphId) -> str:
|
|
675
|
-
if pid.field_id.subfield_id:
|
|
676
|
-
text = extracted_text.split_text.get(pid.field_id.subfield_id) or ""
|
|
677
|
-
else:
|
|
678
|
-
text = extracted_text.text
|
|
679
|
-
return text[pid.paragraph_start : pid.paragraph_end]
|
|
604
|
+
|
|
605
|
+
augmented = await rpc.augment(
|
|
606
|
+
kbid,
|
|
607
|
+
AugmentRequest(
|
|
608
|
+
paragraphs=[
|
|
609
|
+
AugmentParagraphs(
|
|
610
|
+
given=[AugmentParagraph(id=pid.full()) for pid in retrieved_paragraphs_ids],
|
|
611
|
+
text=True,
|
|
612
|
+
neighbours_before=strategy.before,
|
|
613
|
+
neighbours_after=strategy.after,
|
|
614
|
+
)
|
|
615
|
+
]
|
|
616
|
+
),
|
|
617
|
+
)
|
|
680
618
|
|
|
681
619
|
for pid in retrieved_paragraphs_ids:
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
if field_extracted_text is None:
|
|
620
|
+
paragraph = augmented.paragraphs.get(pid.full())
|
|
621
|
+
if paragraph is None:
|
|
685
622
|
continue
|
|
686
|
-
|
|
623
|
+
|
|
624
|
+
ptext = paragraph.text or ""
|
|
687
625
|
if ptext and pid.full() not in context:
|
|
688
626
|
context[pid.full()] = ptext
|
|
689
627
|
|
|
690
628
|
# Now add the neighbouring paragraphs
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
field_pids = [
|
|
696
|
-
ParagraphId(
|
|
697
|
-
field_id=pid.field_id,
|
|
698
|
-
paragraph_start=p.start,
|
|
699
|
-
paragraph_end=p.end,
|
|
700
|
-
)
|
|
701
|
-
for p in field_extracted_metadata.metadata.paragraphs
|
|
629
|
+
neighbour_ids = [
|
|
630
|
+
*(paragraph.neighbours_before or []),
|
|
631
|
+
*(paragraph.neighbours_after or []),
|
|
702
632
|
]
|
|
703
|
-
|
|
704
|
-
|
|
705
|
-
|
|
706
|
-
continue
|
|
633
|
+
for npid in neighbour_ids:
|
|
634
|
+
neighbour = augmented.paragraphs.get(npid)
|
|
635
|
+
assert neighbour is not None, "augment should never return dangling paragraph references"
|
|
707
636
|
|
|
708
|
-
|
|
709
|
-
|
|
710
|
-
before=strategy.before,
|
|
711
|
-
after=strategy.after,
|
|
712
|
-
field_pids=field_pids,
|
|
713
|
-
):
|
|
714
|
-
if neighbour_index == index:
|
|
715
|
-
# Already handled above
|
|
637
|
+
if ParagraphId.from_string(npid) in retrieved_paragraphs_ids or npid in context:
|
|
638
|
+
# already added
|
|
716
639
|
continue
|
|
717
|
-
|
|
718
|
-
|
|
719
|
-
|
|
720
|
-
continue
|
|
721
|
-
if npid in retrieved_paragraphs_ids or npid.full() in context:
|
|
722
|
-
# Already added
|
|
723
|
-
continue
|
|
724
|
-
ptext = _get_paragraph_text(field_extracted_text, npid)
|
|
725
|
-
if not ptext:
|
|
640
|
+
|
|
641
|
+
ntext = neighbour.text
|
|
642
|
+
if not ntext:
|
|
726
643
|
continue
|
|
727
|
-
|
|
728
|
-
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
644
|
+
|
|
645
|
+
context[npid] = ntext
|
|
646
|
+
augmented_context.paragraphs[npid] = AugmentedTextBlock(
|
|
647
|
+
id=npid,
|
|
648
|
+
text=ntext,
|
|
649
|
+
position=neighbour.position,
|
|
732
650
|
parent=pid.full(),
|
|
733
651
|
augmentation_type=TextBlockAugmentationType.NEIGHBOURING_PARAGRAPHS,
|
|
734
652
|
)
|
|
@@ -738,7 +656,7 @@ async def neighbouring_paragraphs_prompt_context(
|
|
|
738
656
|
|
|
739
657
|
def get_text_position(
|
|
740
658
|
paragraph_id: ParagraphId, index: int, field_metadata: FieldComputedMetadata
|
|
741
|
-
) ->
|
|
659
|
+
) -> TextPosition | None:
|
|
742
660
|
if paragraph_id.field_id.subfield_id:
|
|
743
661
|
metadata = field_metadata.split_metadata[paragraph_id.field_id.subfield_id]
|
|
744
662
|
else:
|
|
@@ -777,148 +695,144 @@ async def conversation_prompt_context(
|
|
|
777
695
|
metrics: Metrics,
|
|
778
696
|
augmented_context: AugmentedContext,
|
|
779
697
|
):
|
|
780
|
-
analyzed_fields:
|
|
698
|
+
analyzed_fields: list[str] = []
|
|
781
699
|
ops = 0
|
|
782
|
-
async with get_driver().ro_transaction() as txn:
|
|
783
|
-
storage = await get_storage()
|
|
784
|
-
kb = KnowledgeBoxORM(txn, storage, kbid)
|
|
785
|
-
for paragraph in ordered_paragraphs:
|
|
786
|
-
if paragraph.id not in context:
|
|
787
|
-
context[paragraph.id] = _clean_paragraph_text(paragraph)
|
|
788
700
|
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
701
|
+
conversation_paragraphs = []
|
|
702
|
+
for paragraph in ordered_paragraphs:
|
|
703
|
+
if paragraph.id not in context:
|
|
704
|
+
context[paragraph.id] = _clean_paragraph_text(paragraph)
|
|
705
|
+
|
|
706
|
+
parent_paragraph_id = ParagraphId.from_string(paragraph.id)
|
|
707
|
+
|
|
708
|
+
if parent_paragraph_id.field_id.type != FieldTypeName.CONVERSATION.abbreviation():
|
|
709
|
+
# conversational strategy only applies to conversation fields
|
|
710
|
+
continue
|
|
711
|
+
|
|
712
|
+
field_unique_id = parent_paragraph_id.field_id.full_without_subfield()
|
|
713
|
+
if field_unique_id in analyzed_fields:
|
|
714
|
+
continue
|
|
715
|
+
|
|
716
|
+
conversation_paragraphs.append((parent_paragraph_id, paragraph))
|
|
717
|
+
|
|
718
|
+
# augment conversation paragraphs
|
|
719
|
+
|
|
720
|
+
if strategy.full:
|
|
721
|
+
full_conversation = True
|
|
722
|
+
max_conversation_messages = None
|
|
723
|
+
else:
|
|
724
|
+
full_conversation = False
|
|
725
|
+
max_conversation_messages = strategy.max_messages
|
|
726
|
+
|
|
727
|
+
augment = AugmentRequest(
|
|
728
|
+
fields=[
|
|
729
|
+
AugmentFields(
|
|
730
|
+
given=[paragraph_id.field_id.full() for paragraph_id, _ in conversation_paragraphs],
|
|
731
|
+
full_conversation=full_conversation,
|
|
732
|
+
max_conversation_messages=max_conversation_messages,
|
|
733
|
+
conversation_text_attachments=strategy.attachments_text,
|
|
734
|
+
conversation_image_attachments=strategy.attachments_images,
|
|
735
|
+
)
|
|
736
|
+
]
|
|
737
|
+
)
|
|
738
|
+
augmented = await rpc.augment(kbid, augment)
|
|
739
|
+
|
|
740
|
+
attachments: dict[ParagraphId, list[FieldId]] = {}
|
|
741
|
+
for parent_paragraph_id, paragraph in conversation_paragraphs:
|
|
742
|
+
fid = parent_paragraph_id.field_id
|
|
743
|
+
field = augmented.fields.get(fid.full_without_subfield())
|
|
744
|
+
if field is not None:
|
|
745
|
+
field = cast(AugmentedConversationField, field)
|
|
746
|
+
for _message in field.messages or []:
|
|
747
|
+
ops += 1
|
|
748
|
+
if not _message.text:
|
|
749
|
+
continue
|
|
750
|
+
|
|
751
|
+
text = _message.text
|
|
752
|
+
pid = ParagraphId(
|
|
753
|
+
field_id=FieldId(
|
|
754
|
+
rid=fid.rid,
|
|
755
|
+
type=fid.type,
|
|
756
|
+
key=fid.key,
|
|
757
|
+
subfield_id=_message.ident,
|
|
758
|
+
),
|
|
759
|
+
paragraph_start=0,
|
|
760
|
+
paragraph_end=len(text),
|
|
761
|
+
).full()
|
|
762
|
+
if pid in context:
|
|
800
763
|
continue
|
|
801
|
-
|
|
802
|
-
|
|
764
|
+
context[pid] = text
|
|
765
|
+
|
|
766
|
+
attachments.setdefault(parent_paragraph_id, []).extend(
|
|
767
|
+
[FieldId.from_string(attachment_id) for attachment_id in field.attachments or []]
|
|
768
|
+
)
|
|
769
|
+
augmented_context.paragraphs[pid] = AugmentedTextBlock(
|
|
770
|
+
id=pid,
|
|
771
|
+
text=text,
|
|
772
|
+
parent=paragraph.id,
|
|
773
|
+
augmentation_type=TextBlockAugmentationType.CONVERSATION,
|
|
774
|
+
)
|
|
775
|
+
|
|
776
|
+
# augment attachments
|
|
777
|
+
|
|
778
|
+
if strategy.attachments_text or (
|
|
779
|
+
(strategy.attachments_images and visual_llm) and len(attachments) > 0
|
|
780
|
+
):
|
|
781
|
+
augment = AugmentRequest(
|
|
782
|
+
fields=[
|
|
783
|
+
AugmentFields(
|
|
784
|
+
given=[
|
|
785
|
+
id.full()
|
|
786
|
+
for paragraph_attachments in attachments.values()
|
|
787
|
+
for id in paragraph_attachments
|
|
788
|
+
],
|
|
789
|
+
text=strategy.attachments_text,
|
|
790
|
+
file_thumbnail=(strategy.attachments_images and visual_llm),
|
|
791
|
+
)
|
|
792
|
+
]
|
|
793
|
+
)
|
|
794
|
+
augmented = await rpc.augment(kbid, augment)
|
|
795
|
+
|
|
796
|
+
for parent_paragraph_id, paragraph_attachments in attachments.items():
|
|
797
|
+
for attachment_id in paragraph_attachments:
|
|
798
|
+
attachment_field = augmented.fields.get(attachment_id.full())
|
|
799
|
+
|
|
800
|
+
if attachment_field is None:
|
|
803
801
|
continue
|
|
804
802
|
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
) # type: ignore
|
|
808
|
-
cmetadata = await field_obj.get_metadata()
|
|
809
|
-
|
|
810
|
-
attachments: List[resources_pb2.FieldRef] = []
|
|
811
|
-
if strategy.full:
|
|
812
|
-
ops += 5
|
|
813
|
-
extracted_text = await field_obj.get_extracted_text()
|
|
814
|
-
for current_page in range(1, cmetadata.pages + 1):
|
|
815
|
-
conv = await field_obj.db_get_value(current_page)
|
|
816
|
-
|
|
817
|
-
for message in conv.messages:
|
|
818
|
-
ident = message.ident
|
|
819
|
-
if extracted_text is not None:
|
|
820
|
-
text = extracted_text.split_text.get(ident, message.content.text.strip())
|
|
821
|
-
else:
|
|
822
|
-
text = message.content.text.strip()
|
|
823
|
-
pid = f"{rid}/{field_type}/{field_id}/{ident}/0-{len(text) + 1}"
|
|
824
|
-
attachments.extend(message.content.attachments_fields)
|
|
825
|
-
if pid in context:
|
|
826
|
-
continue
|
|
827
|
-
context[pid] = text
|
|
828
|
-
augmented_context.paragraphs[pid] = AugmentedTextBlock(
|
|
829
|
-
id=pid,
|
|
830
|
-
text=text,
|
|
831
|
-
parent=paragraph.id,
|
|
832
|
-
augmentation_type=TextBlockAugmentationType.CONVERSATION,
|
|
833
|
-
)
|
|
834
|
-
else:
|
|
835
|
-
# Add first message
|
|
836
|
-
extracted_text = await field_obj.get_extracted_text()
|
|
837
|
-
first_page = await field_obj.db_get_value()
|
|
838
|
-
if len(first_page.messages) > 0:
|
|
839
|
-
message = first_page.messages[0]
|
|
840
|
-
ident = message.ident
|
|
841
|
-
if extracted_text is not None:
|
|
842
|
-
text = extracted_text.split_text.get(ident, message.content.text.strip())
|
|
843
|
-
else:
|
|
844
|
-
text = message.content.text.strip()
|
|
845
|
-
attachments.extend(message.content.attachments_fields)
|
|
846
|
-
pid = f"{rid}/{field_type}/{field_id}/{ident}/0-{len(text) + 1}"
|
|
847
|
-
if pid in context:
|
|
848
|
-
continue
|
|
849
|
-
context[pid] = text
|
|
850
|
-
augmented_context.paragraphs[pid] = AugmentedTextBlock(
|
|
851
|
-
id=pid,
|
|
852
|
-
text=text,
|
|
853
|
-
parent=paragraph.id,
|
|
854
|
-
augmentation_type=TextBlockAugmentationType.CONVERSATION,
|
|
855
|
-
)
|
|
803
|
+
if strategy.attachments_text and attachment_field.text:
|
|
804
|
+
ops += 1
|
|
856
805
|
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
for page in range(1, cmetadata.pages + 1):
|
|
861
|
-
# Collect the messages with the window asked by the user arround the match paragraph
|
|
862
|
-
conv = await field_obj.db_get_value(page)
|
|
863
|
-
for message in conv.messages:
|
|
864
|
-
messages.append(message)
|
|
865
|
-
if pending > 0:
|
|
866
|
-
pending -= 1
|
|
867
|
-
if message.ident == mident:
|
|
868
|
-
pending = (strategy.max_messages - 1) // 2
|
|
869
|
-
if pending == 0:
|
|
870
|
-
break
|
|
871
|
-
if pending == 0:
|
|
872
|
-
break
|
|
873
|
-
|
|
874
|
-
for message in messages:
|
|
875
|
-
ops += 1
|
|
876
|
-
text = message.content.text.strip()
|
|
877
|
-
attachments.extend(message.content.attachments_fields)
|
|
878
|
-
pid = f"{rid}/{field_type}/{field_id}/{message.ident}/0-{len(message.content.text) + 1}"
|
|
879
|
-
if pid in context:
|
|
880
|
-
continue
|
|
806
|
+
pid = f"{attachment_id.full_without_subfield()}/0-{len(attachment_field.text)}"
|
|
807
|
+
if pid not in context:
|
|
808
|
+
text = f"Attachment {attachment_id.key}: {attachment_field.text}\n\n"
|
|
881
809
|
context[pid] = text
|
|
882
810
|
augmented_context.paragraphs[pid] = AugmentedTextBlock(
|
|
883
811
|
id=pid,
|
|
884
812
|
text=text,
|
|
885
|
-
parent=
|
|
813
|
+
parent=parent_paragraph_id.full(),
|
|
886
814
|
augmentation_type=TextBlockAugmentationType.CONVERSATION,
|
|
887
815
|
)
|
|
888
816
|
|
|
889
|
-
if
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
)
|
|
909
|
-
|
|
910
|
-
if strategy.attachments_images and visual_llm:
|
|
911
|
-
for attachment in attachments:
|
|
912
|
-
ops += 1
|
|
913
|
-
file_field: File = await resource.get_field(
|
|
914
|
-
attachment.field_id, attachment.field_type, load=True
|
|
915
|
-
) # type: ignore
|
|
916
|
-
image = await get_file_thumbnail_image(file_field)
|
|
917
|
-
if image is not None:
|
|
918
|
-
pid = f"{rid}/f/{attachment.field_id}/0-0"
|
|
919
|
-
context.images[pid] = image
|
|
920
|
-
|
|
921
|
-
analyzed_fields.append(field_unique_id)
|
|
817
|
+
if (
|
|
818
|
+
(strategy.attachments_images and visual_llm)
|
|
819
|
+
and isinstance(attachment_field, AugmentedFileField)
|
|
820
|
+
and attachment_field.thumbnail_image
|
|
821
|
+
):
|
|
822
|
+
ops += 1
|
|
823
|
+
|
|
824
|
+
image = await rpc.download_image(
|
|
825
|
+
kbid,
|
|
826
|
+
attachment_id,
|
|
827
|
+
attachment_field.thumbnail_image,
|
|
828
|
+
# We assume the thumbnail is always generated as JPEG by Nuclia processing
|
|
829
|
+
mime_type="image/jpeg",
|
|
830
|
+
)
|
|
831
|
+
if image is not None:
|
|
832
|
+
pid = f"{attachment_id.rid}/f/{attachment_id.key}/0-0"
|
|
833
|
+
context.images[pid] = image
|
|
834
|
+
|
|
835
|
+
analyzed_fields.append(field_unique_id)
|
|
922
836
|
metrics.set("conversation_ops", ops)
|
|
923
837
|
|
|
924
838
|
|
|
@@ -939,66 +853,93 @@ async def hierarchy_prompt_context(
|
|
|
939
853
|
# Make a copy of the ordered paragraphs to avoid modifying the original list, which is returned
|
|
940
854
|
# in the response to the user
|
|
941
855
|
ordered_paragraphs_copy = copy.deepcopy(ordered_paragraphs)
|
|
942
|
-
resources:
|
|
856
|
+
resources: dict[str, ExtraCharsParagraph] = {}
|
|
943
857
|
|
|
944
858
|
# Iterate paragraphs to get extended text
|
|
859
|
+
paragraphs_to_augment = []
|
|
945
860
|
for paragraph in ordered_paragraphs_copy:
|
|
946
861
|
paragraph_id = ParagraphId.from_string(paragraph.id)
|
|
947
|
-
extended_paragraph_text = paragraph.text
|
|
948
|
-
if paragraphs_extra_characters > 0:
|
|
949
|
-
extended_paragraph_text = await get_paragraph_text(
|
|
950
|
-
kbid=kbid,
|
|
951
|
-
paragraph_id=paragraph_id,
|
|
952
|
-
log_on_missing_field=True,
|
|
953
|
-
)
|
|
954
862
|
rid = paragraph_id.rid
|
|
863
|
+
|
|
864
|
+
if paragraphs_extra_characters > 0:
|
|
865
|
+
paragraph_id.paragraph_end += paragraphs_extra_characters
|
|
866
|
+
|
|
867
|
+
paragraphs_to_augment.append(paragraph_id)
|
|
868
|
+
|
|
955
869
|
if rid not in resources:
|
|
956
870
|
# Get the title and the summary of the resource
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
type="a",
|
|
963
|
-
key="title",
|
|
964
|
-
),
|
|
965
|
-
paragraph_start=0,
|
|
966
|
-
paragraph_end=500,
|
|
871
|
+
title_paragraph_id = ParagraphId(
|
|
872
|
+
field_id=FieldId(
|
|
873
|
+
rid=rid,
|
|
874
|
+
type="a",
|
|
875
|
+
key="title",
|
|
967
876
|
),
|
|
968
|
-
|
|
877
|
+
paragraph_start=0,
|
|
878
|
+
paragraph_end=500,
|
|
969
879
|
)
|
|
970
|
-
|
|
971
|
-
|
|
972
|
-
|
|
973
|
-
|
|
974
|
-
|
|
975
|
-
type="a",
|
|
976
|
-
key="summary",
|
|
977
|
-
),
|
|
978
|
-
paragraph_start=0,
|
|
979
|
-
paragraph_end=1000,
|
|
880
|
+
summary_paragraph_id = ParagraphId(
|
|
881
|
+
field_id=FieldId(
|
|
882
|
+
rid=rid,
|
|
883
|
+
type="a",
|
|
884
|
+
key="summary",
|
|
980
885
|
),
|
|
981
|
-
|
|
886
|
+
paragraph_start=0,
|
|
887
|
+
paragraph_end=1000,
|
|
982
888
|
)
|
|
889
|
+
paragraphs_to_augment.append(title_paragraph_id)
|
|
890
|
+
paragraphs_to_augment.append(summary_paragraph_id)
|
|
891
|
+
|
|
983
892
|
resources[rid] = ExtraCharsParagraph(
|
|
984
|
-
title=
|
|
985
|
-
summary=
|
|
986
|
-
paragraphs=[(paragraph,
|
|
893
|
+
title=title_paragraph_id,
|
|
894
|
+
summary=summary_paragraph_id,
|
|
895
|
+
paragraphs=[(paragraph, paragraph_id)],
|
|
987
896
|
)
|
|
988
897
|
else:
|
|
989
|
-
resources[rid].paragraphs.append((paragraph,
|
|
898
|
+
resources[rid].paragraphs.append((paragraph, paragraph_id))
|
|
990
899
|
|
|
991
900
|
metrics.set("hierarchy_ops", len(resources))
|
|
901
|
+
|
|
902
|
+
augmented = await rpc.augment(
|
|
903
|
+
kbid,
|
|
904
|
+
AugmentRequest(
|
|
905
|
+
paragraphs=[
|
|
906
|
+
AugmentParagraphs(
|
|
907
|
+
given=[
|
|
908
|
+
AugmentParagraph(id=paragraph_id.full())
|
|
909
|
+
for paragraph_id in paragraphs_to_augment
|
|
910
|
+
],
|
|
911
|
+
text=True,
|
|
912
|
+
)
|
|
913
|
+
]
|
|
914
|
+
),
|
|
915
|
+
)
|
|
916
|
+
|
|
992
917
|
augmented_paragraphs = set()
|
|
993
918
|
|
|
994
919
|
# Modify the first paragraph of each resource to include the title and summary of the resource, as well as the
|
|
995
920
|
# extended paragraph text of all the paragraphs in the resource.
|
|
996
921
|
for values in resources.values():
|
|
997
|
-
|
|
998
|
-
|
|
922
|
+
augmented_title = augmented.paragraphs.get(values.title.full())
|
|
923
|
+
if augmented_title:
|
|
924
|
+
title_text = augmented_title.text or ""
|
|
925
|
+
else:
|
|
926
|
+
title_text = ""
|
|
927
|
+
|
|
928
|
+
augmented_summary = augmented.paragraphs.get(values.summary.full())
|
|
929
|
+
if augmented_summary:
|
|
930
|
+
summary_text = augmented_summary.text or ""
|
|
931
|
+
else:
|
|
932
|
+
summary_text = ""
|
|
933
|
+
|
|
999
934
|
first_paragraph = None
|
|
1000
935
|
text_with_hierarchy = ""
|
|
1001
|
-
for paragraph,
|
|
936
|
+
for paragraph, paragraph_id in values.paragraphs:
|
|
937
|
+
augmented_paragraph = augmented.paragraphs.get(paragraph_id.full())
|
|
938
|
+
if augmented_paragraph:
|
|
939
|
+
extended_paragraph_text = augmented_paragraph.text or ""
|
|
940
|
+
else:
|
|
941
|
+
extended_paragraph_text = ""
|
|
942
|
+
|
|
1002
943
|
if first_paragraph is None:
|
|
1003
944
|
first_paragraph = paragraph
|
|
1004
945
|
text_with_hierarchy += "\n EXTRACTED BLOCK: \n " + extended_paragraph_text + " \n\n "
|
|
@@ -1035,14 +976,14 @@ class PromptContextBuilder:
|
|
|
1035
976
|
self,
|
|
1036
977
|
kbid: str,
|
|
1037
978
|
ordered_paragraphs: list[FindParagraph],
|
|
1038
|
-
resource:
|
|
1039
|
-
user_context:
|
|
1040
|
-
user_image_context:
|
|
1041
|
-
strategies:
|
|
1042
|
-
image_strategies:
|
|
1043
|
-
max_context_characters:
|
|
979
|
+
resource: str | None = None,
|
|
980
|
+
user_context: list[str] | None = None,
|
|
981
|
+
user_image_context: list[Image] | None = None,
|
|
982
|
+
strategies: Sequence[RagStrategy] | None = None,
|
|
983
|
+
image_strategies: Sequence[ImageRagStrategy] | None = None,
|
|
984
|
+
max_context_characters: int | None = None,
|
|
1044
985
|
visual_llm: bool = False,
|
|
1045
|
-
query_image:
|
|
986
|
+
query_image: Image | None = None,
|
|
1046
987
|
metrics: Metrics = Metrics("prompt_context_builder"),
|
|
1047
988
|
):
|
|
1048
989
|
self.kbid = kbid
|
|
@@ -1088,10 +1029,10 @@ class PromptContextBuilder:
|
|
|
1088
1029
|
if self.image_strategies is None or len(self.image_strategies) == 0:
|
|
1089
1030
|
# Nothing to do
|
|
1090
1031
|
return
|
|
1091
|
-
page_image_strategy:
|
|
1032
|
+
page_image_strategy: PageImageStrategy | None = None
|
|
1092
1033
|
max_page_images = 5
|
|
1093
|
-
table_image_strategy:
|
|
1094
|
-
paragraph_image_strategy:
|
|
1034
|
+
table_image_strategy: TableImageStrategy | None = None
|
|
1035
|
+
paragraph_image_strategy: ParagraphImageStrategy | None = None
|
|
1095
1036
|
for strategy in self.image_strategies:
|
|
1096
1037
|
if strategy.name == ImageRagStrategyName.PAGE_IMAGE:
|
|
1097
1038
|
if page_image_strategy is None:
|
|
@@ -1121,7 +1062,12 @@ class PromptContextBuilder:
|
|
|
1121
1062
|
# page_image_id: rid/f/myfield/0
|
|
1122
1063
|
page_image_id = "/".join([pid.field_id.full(), str(paragraph_page_number)])
|
|
1123
1064
|
if page_image_id not in context.images:
|
|
1124
|
-
image = await
|
|
1065
|
+
image = await rpc.download_image(
|
|
1066
|
+
self.kbid,
|
|
1067
|
+
pid.field_id,
|
|
1068
|
+
f"generated/extracted_images_{paragraph_page_number}.png",
|
|
1069
|
+
mime_type="image/png",
|
|
1070
|
+
)
|
|
1125
1071
|
if image is not None:
|
|
1126
1072
|
ops += 1
|
|
1127
1073
|
context.images[page_image_id] = image
|
|
@@ -1141,7 +1087,9 @@ class PromptContextBuilder:
|
|
|
1141
1087
|
if (add_table or add_paragraph) and (
|
|
1142
1088
|
paragraph.reference is not None and paragraph.reference != ""
|
|
1143
1089
|
):
|
|
1144
|
-
pimage = await
|
|
1090
|
+
pimage = await rpc.download_image(
|
|
1091
|
+
self.kbid, pid.field_id, f"generated/{paragraph.reference}", mime_type="image/png"
|
|
1092
|
+
)
|
|
1145
1093
|
if pimage is not None:
|
|
1146
1094
|
ops += 1
|
|
1147
1095
|
context.images[paragraph.id] = pimage
|
|
@@ -1171,12 +1119,12 @@ class PromptContextBuilder:
|
|
|
1171
1119
|
RagStrategyName.GRAPH,
|
|
1172
1120
|
]
|
|
1173
1121
|
|
|
1174
|
-
full_resource:
|
|
1175
|
-
hierarchy:
|
|
1176
|
-
neighbouring_paragraphs:
|
|
1177
|
-
field_extension:
|
|
1178
|
-
metadata_extension:
|
|
1179
|
-
conversational_strategy:
|
|
1122
|
+
full_resource: FullResourceStrategy | None = None
|
|
1123
|
+
hierarchy: HierarchyResourceStrategy | None = None
|
|
1124
|
+
neighbouring_paragraphs: NeighbouringParagraphsStrategy | None = None
|
|
1125
|
+
field_extension: FieldExtensionStrategy | None = None
|
|
1126
|
+
metadata_extension: MetadataExtensionStrategy | None = None
|
|
1127
|
+
conversational_strategy: ConversationalStrategy | None = None
|
|
1180
1128
|
for strategy in self.strategies:
|
|
1181
1129
|
if strategy.name == RagStrategyName.FIELD_EXTENSION:
|
|
1182
1130
|
field_extension = cast(FieldExtensionStrategy, strategy)
|
|
@@ -1269,7 +1217,7 @@ class PromptContextBuilder:
|
|
|
1269
1217
|
)
|
|
1270
1218
|
|
|
1271
1219
|
|
|
1272
|
-
def get_paragraph_page_number(paragraph: FindParagraph) ->
|
|
1220
|
+
def get_paragraph_page_number(paragraph: FindParagraph) -> int | None:
|
|
1273
1221
|
if not paragraph.page_with_visual:
|
|
1274
1222
|
return None
|
|
1275
1223
|
if paragraph.position is None:
|
|
@@ -1279,9 +1227,9 @@ def get_paragraph_page_number(paragraph: FindParagraph) -> Optional[int]:
|
|
|
1279
1227
|
|
|
1280
1228
|
@dataclass
|
|
1281
1229
|
class ExtraCharsParagraph:
|
|
1282
|
-
title:
|
|
1283
|
-
summary:
|
|
1284
|
-
paragraphs:
|
|
1230
|
+
title: ParagraphId
|
|
1231
|
+
summary: ParagraphId
|
|
1232
|
+
paragraphs: list[tuple[FindParagraph, ParagraphId]]
|
|
1285
1233
|
|
|
1286
1234
|
|
|
1287
1235
|
def _clean_paragraph_text(paragraph: FindParagraph) -> str:
|