nucliadb 6.7.2.post4874__py3-none-any.whl → 6.10.0.post5705__py3-none-any.whl
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registries.
- migrations/0023_backfill_pg_catalog.py +8 -4
- migrations/0028_extracted_vectors_reference.py +1 -1
- migrations/0029_backfill_field_status.py +3 -4
- migrations/0032_remove_old_relations.py +2 -3
- migrations/0038_backfill_catalog_field_labels.py +8 -4
- migrations/0039_backfill_converation_splits_metadata.py +106 -0
- migrations/0040_migrate_search_configurations.py +79 -0
- migrations/0041_reindex_conversations.py +137 -0
- migrations/pg/0010_shards_index.py +34 -0
- nucliadb/search/api/v1/resource/utils.py → migrations/pg/0011_catalog_statistics.py +5 -6
- migrations/pg/0012_catalog_statistics_undo.py +26 -0
- nucliadb/backups/create.py +2 -15
- nucliadb/backups/restore.py +4 -15
- nucliadb/backups/tasks.py +4 -1
- nucliadb/common/back_pressure/cache.py +2 -3
- nucliadb/common/back_pressure/materializer.py +7 -13
- nucliadb/common/back_pressure/settings.py +6 -6
- nucliadb/common/back_pressure/utils.py +1 -0
- nucliadb/common/cache.py +9 -9
- nucliadb/common/catalog/__init__.py +79 -0
- nucliadb/common/catalog/dummy.py +36 -0
- nucliadb/common/catalog/interface.py +85 -0
- nucliadb/{search/search/pgcatalog.py → common/catalog/pg.py} +330 -232
- nucliadb/common/catalog/utils.py +56 -0
- nucliadb/common/cluster/manager.py +8 -23
- nucliadb/common/cluster/rebalance.py +484 -112
- nucliadb/common/cluster/rollover.py +36 -9
- nucliadb/common/cluster/settings.py +4 -9
- nucliadb/common/cluster/utils.py +34 -8
- nucliadb/common/context/__init__.py +7 -8
- nucliadb/common/context/fastapi.py +1 -2
- nucliadb/common/datamanagers/__init__.py +2 -4
- nucliadb/common/datamanagers/atomic.py +9 -2
- nucliadb/common/datamanagers/cluster.py +1 -2
- nucliadb/common/datamanagers/fields.py +3 -4
- nucliadb/common/datamanagers/kb.py +6 -6
- nucliadb/common/datamanagers/labels.py +2 -3
- nucliadb/common/datamanagers/resources.py +10 -33
- nucliadb/common/datamanagers/rollover.py +5 -7
- nucliadb/common/datamanagers/search_configurations.py +1 -2
- nucliadb/common/datamanagers/synonyms.py +1 -2
- nucliadb/common/datamanagers/utils.py +4 -4
- nucliadb/common/datamanagers/vectorsets.py +4 -4
- nucliadb/common/external_index_providers/base.py +32 -5
- nucliadb/common/external_index_providers/manager.py +5 -34
- nucliadb/common/external_index_providers/settings.py +1 -27
- nucliadb/common/filter_expression.py +129 -41
- nucliadb/common/http_clients/exceptions.py +8 -0
- nucliadb/common/http_clients/processing.py +16 -23
- nucliadb/common/http_clients/utils.py +3 -0
- nucliadb/common/ids.py +82 -58
- nucliadb/common/locking.py +1 -2
- nucliadb/common/maindb/driver.py +9 -8
- nucliadb/common/maindb/local.py +5 -5
- nucliadb/common/maindb/pg.py +9 -8
- nucliadb/common/nidx.py +22 -5
- nucliadb/common/vector_index_config.py +1 -1
- nucliadb/export_import/datamanager.py +4 -3
- nucliadb/export_import/exporter.py +11 -19
- nucliadb/export_import/importer.py +13 -6
- nucliadb/export_import/tasks.py +2 -0
- nucliadb/export_import/utils.py +6 -18
- nucliadb/health.py +2 -2
- nucliadb/ingest/app.py +8 -8
- nucliadb/ingest/consumer/consumer.py +8 -10
- nucliadb/ingest/consumer/pull.py +10 -8
- nucliadb/ingest/consumer/service.py +5 -30
- nucliadb/ingest/consumer/shard_creator.py +16 -5
- nucliadb/ingest/consumer/utils.py +1 -1
- nucliadb/ingest/fields/base.py +37 -49
- nucliadb/ingest/fields/conversation.py +55 -9
- nucliadb/ingest/fields/exceptions.py +1 -2
- nucliadb/ingest/fields/file.py +22 -8
- nucliadb/ingest/fields/link.py +7 -7
- nucliadb/ingest/fields/text.py +2 -3
- nucliadb/ingest/orm/brain_v2.py +89 -57
- nucliadb/ingest/orm/broker_message.py +2 -4
- nucliadb/ingest/orm/entities.py +10 -209
- nucliadb/ingest/orm/index_message.py +128 -113
- nucliadb/ingest/orm/knowledgebox.py +91 -59
- nucliadb/ingest/orm/processor/auditing.py +1 -3
- nucliadb/ingest/orm/processor/data_augmentation.py +1 -2
- nucliadb/ingest/orm/processor/processor.py +98 -153
- nucliadb/ingest/orm/processor/sequence_manager.py +1 -2
- nucliadb/ingest/orm/resource.py +82 -71
- nucliadb/ingest/orm/utils.py +1 -1
- nucliadb/ingest/partitions.py +12 -1
- nucliadb/ingest/processing.py +17 -17
- nucliadb/ingest/serialize.py +202 -145
- nucliadb/ingest/service/writer.py +15 -114
- nucliadb/ingest/settings.py +36 -15
- nucliadb/ingest/utils.py +1 -2
- nucliadb/learning_proxy.py +23 -26
- nucliadb/metrics_exporter.py +20 -6
- nucliadb/middleware/__init__.py +82 -1
- nucliadb/migrator/datamanager.py +4 -11
- nucliadb/migrator/migrator.py +1 -2
- nucliadb/migrator/models.py +1 -2
- nucliadb/migrator/settings.py +1 -2
- nucliadb/models/internal/augment.py +614 -0
- nucliadb/models/internal/processing.py +19 -19
- nucliadb/openapi.py +2 -2
- nucliadb/purge/__init__.py +3 -8
- nucliadb/purge/orphan_shards.py +1 -2
- nucliadb/reader/__init__.py +5 -0
- nucliadb/reader/api/models.py +6 -13
- nucliadb/reader/api/v1/download.py +59 -38
- nucliadb/reader/api/v1/export_import.py +4 -4
- nucliadb/reader/api/v1/knowledgebox.py +37 -9
- nucliadb/reader/api/v1/learning_config.py +33 -14
- nucliadb/reader/api/v1/resource.py +61 -9
- nucliadb/reader/api/v1/services.py +18 -14
- nucliadb/reader/app.py +3 -1
- nucliadb/reader/reader/notifications.py +1 -2
- nucliadb/search/api/v1/__init__.py +3 -0
- nucliadb/search/api/v1/ask.py +3 -4
- nucliadb/search/api/v1/augment.py +585 -0
- nucliadb/search/api/v1/catalog.py +15 -19
- nucliadb/search/api/v1/find.py +16 -22
- nucliadb/search/api/v1/hydrate.py +328 -0
- nucliadb/search/api/v1/knowledgebox.py +1 -2
- nucliadb/search/api/v1/predict_proxy.py +1 -2
- nucliadb/search/api/v1/resource/ask.py +28 -8
- nucliadb/search/api/v1/resource/ingestion_agents.py +5 -6
- nucliadb/search/api/v1/resource/search.py +9 -11
- nucliadb/search/api/v1/retrieve.py +130 -0
- nucliadb/search/api/v1/search.py +28 -32
- nucliadb/search/api/v1/suggest.py +11 -14
- nucliadb/search/api/v1/summarize.py +1 -2
- nucliadb/search/api/v1/utils.py +2 -2
- nucliadb/search/app.py +3 -2
- nucliadb/search/augmentor/__init__.py +21 -0
- nucliadb/search/augmentor/augmentor.py +232 -0
- nucliadb/search/augmentor/fields.py +704 -0
- nucliadb/search/augmentor/metrics.py +24 -0
- nucliadb/search/augmentor/paragraphs.py +334 -0
- nucliadb/search/augmentor/resources.py +238 -0
- nucliadb/search/augmentor/utils.py +33 -0
- nucliadb/search/lifecycle.py +3 -1
- nucliadb/search/predict.py +33 -19
- nucliadb/search/predict_models.py +8 -9
- nucliadb/search/requesters/utils.py +11 -10
- nucliadb/search/search/cache.py +19 -42
- nucliadb/search/search/chat/ask.py +131 -59
- nucliadb/search/search/chat/exceptions.py +3 -5
- nucliadb/search/search/chat/fetcher.py +201 -0
- nucliadb/search/search/chat/images.py +6 -4
- nucliadb/search/search/chat/old_prompt.py +1375 -0
- nucliadb/search/search/chat/parser.py +510 -0
- nucliadb/search/search/chat/prompt.py +563 -615
- nucliadb/search/search/chat/query.py +453 -32
- nucliadb/search/search/chat/rpc.py +85 -0
- nucliadb/search/search/fetch.py +3 -4
- nucliadb/search/search/filters.py +8 -11
- nucliadb/search/search/find.py +33 -31
- nucliadb/search/search/find_merge.py +124 -331
- nucliadb/search/search/graph_strategy.py +14 -12
- nucliadb/search/search/hydrator/__init__.py +49 -0
- nucliadb/search/search/hydrator/fields.py +217 -0
- nucliadb/search/search/hydrator/images.py +130 -0
- nucliadb/search/search/hydrator/paragraphs.py +323 -0
- nucliadb/search/search/hydrator/resources.py +60 -0
- nucliadb/search/search/ingestion_agents.py +5 -5
- nucliadb/search/search/merge.py +90 -94
- nucliadb/search/search/metrics.py +24 -7
- nucliadb/search/search/paragraphs.py +7 -9
- nucliadb/search/search/predict_proxy.py +44 -18
- nucliadb/search/search/query.py +14 -86
- nucliadb/search/search/query_parser/fetcher.py +51 -82
- nucliadb/search/search/query_parser/models.py +19 -48
- nucliadb/search/search/query_parser/old_filters.py +20 -19
- nucliadb/search/search/query_parser/parsers/ask.py +5 -6
- nucliadb/search/search/query_parser/parsers/catalog.py +7 -11
- nucliadb/search/search/query_parser/parsers/common.py +21 -13
- nucliadb/search/search/query_parser/parsers/find.py +6 -29
- nucliadb/search/search/query_parser/parsers/graph.py +18 -28
- nucliadb/search/search/query_parser/parsers/retrieve.py +207 -0
- nucliadb/search/search/query_parser/parsers/search.py +15 -56
- nucliadb/search/search/query_parser/parsers/unit_retrieval.py +8 -29
- nucliadb/search/search/rank_fusion.py +18 -13
- nucliadb/search/search/rerankers.py +6 -7
- nucliadb/search/search/retrieval.py +300 -0
- nucliadb/search/search/summarize.py +5 -6
- nucliadb/search/search/utils.py +3 -4
- nucliadb/search/settings.py +1 -2
- nucliadb/standalone/api_router.py +1 -1
- nucliadb/standalone/app.py +4 -3
- nucliadb/standalone/auth.py +5 -6
- nucliadb/standalone/lifecycle.py +2 -2
- nucliadb/standalone/run.py +5 -4
- nucliadb/standalone/settings.py +5 -6
- nucliadb/standalone/versions.py +3 -4
- nucliadb/tasks/consumer.py +13 -8
- nucliadb/tasks/models.py +2 -1
- nucliadb/tasks/producer.py +3 -3
- nucliadb/tasks/retries.py +8 -7
- nucliadb/train/api/utils.py +1 -3
- nucliadb/train/api/v1/shards.py +1 -2
- nucliadb/train/api/v1/trainset.py +1 -2
- nucliadb/train/app.py +1 -1
- nucliadb/train/generator.py +4 -4
- nucliadb/train/generators/field_classifier.py +2 -2
- nucliadb/train/generators/field_streaming.py +6 -6
- nucliadb/train/generators/image_classifier.py +2 -2
- nucliadb/train/generators/paragraph_classifier.py +2 -2
- nucliadb/train/generators/paragraph_streaming.py +2 -2
- nucliadb/train/generators/question_answer_streaming.py +2 -2
- nucliadb/train/generators/sentence_classifier.py +4 -10
- nucliadb/train/generators/token_classifier.py +3 -2
- nucliadb/train/generators/utils.py +6 -5
- nucliadb/train/nodes.py +3 -3
- nucliadb/train/resource.py +6 -8
- nucliadb/train/settings.py +3 -4
- nucliadb/train/types.py +11 -11
- nucliadb/train/upload.py +3 -2
- nucliadb/train/uploader.py +1 -2
- nucliadb/train/utils.py +1 -2
- nucliadb/writer/api/v1/export_import.py +4 -1
- nucliadb/writer/api/v1/field.py +15 -14
- nucliadb/writer/api/v1/knowledgebox.py +18 -56
- nucliadb/writer/api/v1/learning_config.py +5 -4
- nucliadb/writer/api/v1/resource.py +9 -20
- nucliadb/writer/api/v1/services.py +10 -132
- nucliadb/writer/api/v1/upload.py +73 -72
- nucliadb/writer/app.py +8 -2
- nucliadb/writer/resource/basic.py +12 -15
- nucliadb/writer/resource/field.py +43 -5
- nucliadb/writer/resource/origin.py +7 -0
- nucliadb/writer/settings.py +2 -3
- nucliadb/writer/tus/__init__.py +2 -3
- nucliadb/writer/tus/azure.py +5 -7
- nucliadb/writer/tus/dm.py +3 -3
- nucliadb/writer/tus/exceptions.py +3 -4
- nucliadb/writer/tus/gcs.py +15 -22
- nucliadb/writer/tus/s3.py +2 -3
- nucliadb/writer/tus/storage.py +3 -3
- {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/METADATA +10 -11
- nucliadb-6.10.0.post5705.dist-info/RECORD +410 -0
- nucliadb/common/datamanagers/entities.py +0 -139
- nucliadb/common/external_index_providers/pinecone.py +0 -894
- nucliadb/ingest/orm/processor/pgcatalog.py +0 -129
- nucliadb/search/search/hydrator.py +0 -197
- nucliadb-6.7.2.post4874.dist-info/RECORD +0 -383
- {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/WHEEL +0 -0
- {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/top_level.txt +0 -0
nucliadb/search/augmentor/fields.py +704 -0
@@ -0,0 +1,704 @@
# Copyright (C) 2021 Bosutech XXI S.L.
#
# nucliadb is offered under the AGPL v3.0 and as commercial software.
# For commercial licensing, contact us at info@nuclia.com.
#
# AGPL:
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
import asyncio
from collections import deque
from collections.abc import AsyncIterator, Sequence
from typing import Deque, cast

from typing_extensions import assert_never

from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB, FieldId
from nucliadb.common.models_utils import from_proto
from nucliadb.ingest.fields.base import Field
from nucliadb.ingest.fields.conversation import Conversation
from nucliadb.ingest.fields.file import File
from nucliadb.ingest.fields.generic import Generic
from nucliadb.ingest.fields.link import Link
from nucliadb.ingest.fields.text import Text
from nucliadb.ingest.orm.resource import Resource
from nucliadb.models.internal.augment import (
    AnswerSelector,
    AugmentedConversationField,
    AugmentedConversationMessage,
    AugmentedField,
    AugmentedFileField,
    AugmentedGenericField,
    AugmentedLinkField,
    AugmentedTextField,
    ConversationAnswerOrAfter,
    ConversationAttachments,
    ConversationProp,
    ConversationSelector,
    ConversationText,
    FieldClassificationLabels,
    FieldEntities,
    FieldProp,
    FieldText,
    FieldValue,
    FileProp,
    FileThumbnail,
    FullSelector,
    MessageSelector,
    NeighboursSelector,
    PageSelector,
    WindowSelector,
)
from nucliadb.search.augmentor.metrics import augmentor_observer
from nucliadb.search.augmentor.resources import get_basic
from nucliadb.search.augmentor.utils import limited_concurrency
from nucliadb.search.search import cache
from nucliadb_models.common import FieldTypeName
from nucliadb_protos import resources_pb2
from nucliadb_utils.storages.storage import STORAGE_FILE_EXTRACTED

# Number of messages to pull after a match in a message.
# The hope is that this is enough to capture the answer to the question.
CONVERSATION_MESSAGE_CONTEXT_EXPANSION = 15


async def augment_fields(
    kbid: str,
    given: list[FieldId],
    select: list[FieldProp | ConversationProp],
    *,
    concurrency_control: asyncio.Semaphore | None = None,
) -> dict[FieldId, AugmentedField | None]:
    """Augment a list of fields according to the requested properties."""

    ops = []
    for field_id in given:
        task = asyncio.create_task(
            limited_concurrency(
                augment_field(kbid, field_id, select),
                max_ops=concurrency_control,
            )
        )
        ops.append(task)
    results: list[AugmentedField | None] = await asyncio.gather(*ops)

    augmented = {}
    for field_id, augmentation in zip(given, results):
        augmented[field_id] = augmentation

    return augmented
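

# Usage sketch (illustrative, not part of this module): fan out augmentation
# over several fields with bounded concurrency. The kbid, field ids and prop
# constructors are placeholders; the real props live in
# nucliadb.models.internal.augment.
#
#     semaphore = asyncio.Semaphore(10)
#     augmented = await augment_fields(
#         kbid="my-kb",
#         given=[field_id_a, field_id_b],
#         select=[FieldText(), FieldClassificationLabels()],
#         concurrency_control=semaphore,
#     )
#     # each given FieldId maps to its AugmentedField, or None when the
#     # resource or field doesn't exist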


@augmentor_observer.wrap({"type": "field"})
async def augment_field(
    kbid: str,
    field_id: FieldId,
    select: Sequence[FieldProp | ConversationProp],
) -> AugmentedField | None:
    rid = field_id.rid
    resource = await cache.get_resource(kbid, rid)
    if resource is None:
        # skip resources that aren't in the DB
        return None

    field_type_pb = FIELD_TYPE_STR_TO_PB[field_id.type]
    # we must check if the field exists or get_field will return an empty
    # field (behaviour intended for ingestion) that we don't want
    if not (await resource.field_exists(field_type_pb, field_id.key)):
        # skip fields that aren't in the DB
        return None
    field = await resource.get_field(field_id.key, field_id.pb_type)

    return await db_augment_field(field, field_id, select)


async def db_augment_field(
    field: Field,
    field_id: FieldId,
    select: Sequence[FieldProp | FileProp | ConversationProp],
) -> AugmentedField:
    select = dedup_field_select(select)

    field_type = field_id.type

    # Note we cast `select` to the specific Union type required by each
    # db_augment_ function. This is safe even if there are props that don't
    # apply to a specific field type, as they will be ignored

    if field_type == FieldTypeName.TEXT.abbreviation():
        field = cast(Text, field)
        select = cast(list[FieldProp], select)
        return await db_augment_text_field(field, field_id, select)

    elif field_type == FieldTypeName.FILE.abbreviation():
        field = cast(File, field)
        select = cast(list[FileProp], select)
        return await db_augment_file_field(field, field_id, select)

    elif field_type == FieldTypeName.LINK.abbreviation():
        field = cast(Link, field)
        select = cast(list[FieldProp], select)
        return await db_augment_link_field(field, field_id, select)

    elif field_type == FieldTypeName.CONVERSATION.abbreviation():
        field = cast(Conversation, field)
        select = cast(list[ConversationProp], select)
        return await db_augment_conversation_field(field, field_id, select)

    elif field_type == FieldTypeName.GENERIC.abbreviation():
        field = cast(Generic, field)
        select = cast(list[FieldProp], select)
        return await db_augment_generic_field(field, field_id, select)

    else:  # pragma: no cover
        assert False, f"unknown field type: {field_type}"


def dedup_field_select(
    select: Sequence[FieldProp | FileProp | ConversationProp],
) -> Sequence[FieldProp | FileProp | ConversationProp]:
    """Merge any duplicated property taking the broadest augmentation possible."""
    merged = {}

    # TODO(decoupled-ask): deduplicate conversation props.
    #
    # Note that only conversation properties can be deduplicated (none of the
    # others carry any parameters). However, deduplicating the selector is not
    # possible in many cases, so we do nothing
    unmergeable = []

    for prop in select:
        if prop.prop not in merged:
            merged[prop.prop] = prop

        else:
            if isinstance(prop, ConversationText) or isinstance(prop, ConversationAttachments):
                unmergeable.append(prop)
            elif (
                isinstance(prop, FieldText)
                or isinstance(prop, FieldValue)
                or isinstance(prop, FieldClassificationLabels)
                or isinstance(prop, FieldEntities)
                or isinstance(prop, FileThumbnail)
                or isinstance(prop, ConversationAnswerOrAfter)
            ):
                # properties without parameters
                pass
            else:  # pragma: no cover
                assert_never(prop)

    return [*merged.values(), *unmergeable]
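
# e.g. (illustrative): selecting [FieldText(), FieldText()] collapses into a
# single FieldText, while two ConversationText props with different selectors
# are both kept, since their selectors can't be merged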


@augmentor_observer.wrap({"type": "db_text_field"})
async def db_augment_text_field(
    field: Text,
    field_id: FieldId,
    select: Sequence[FieldProp],
) -> AugmentedTextField:
    augmented = AugmentedTextField(id=field.field_id)

    for prop in select:
        if isinstance(prop, FieldText):
            augmented.text = await get_field_extracted_text(field_id, field)

        elif isinstance(prop, FieldClassificationLabels):
            augmented.classification_labels = await classification_labels(field_id, field.resource)

        elif isinstance(prop, FieldEntities):
            augmented.entities = await field_entities(field_id, field)

        # text field props

        elif isinstance(prop, FieldValue):
            db_value = await field.get_value()
            if db_value is None:
                continue
            augmented.value = from_proto.field_text(db_value)

        else:  # pragma: no cover
            assert_never(prop)

    return augmented
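
# Illustrative result: with select=[FieldText(), FieldClassificationLabels()]
# and assuming extracted text and labels exist, this returns roughly
#     AugmentedTextField(
#         id=<field id>,
#         text="the extracted text...",
#         classification_labels={"topic": {"sports"}},
#     )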


@augmentor_observer.wrap({"type": "db_file_field"})
async def db_augment_file_field(
    field: File,
    field_id: FieldId,
    select: Sequence[FileProp],
) -> AugmentedFileField:
    augmented = AugmentedFileField(id=field.field_id)

    for prop in select:
        if isinstance(prop, FieldText):
            augmented.text = await get_field_extracted_text(field_id, field)

        elif isinstance(prop, FieldClassificationLabels):
            augmented.classification_labels = await classification_labels(field_id, field.resource)

        elif isinstance(prop, FieldEntities):
            augmented.entities = await field_entities(field_id, field)

        # file field props

        elif isinstance(prop, FieldValue):
            db_value = await field.get_value()
            if db_value is None:
                continue
            augmented.value = from_proto.field_file(db_value)

        elif isinstance(prop, FileThumbnail):
            augmented.thumbnail_path = await get_file_thumbnail_path(field, field_id)

        else:  # pragma: no cover
            assert_never(prop)

    return augmented


@augmentor_observer.wrap({"type": "db_link_field"})
async def db_augment_link_field(
    field: Link,
    field_id: FieldId,
    select: Sequence[FieldProp],
) -> AugmentedLinkField:
    augmented = AugmentedLinkField(id=field.field_id)

    for prop in select:
        if isinstance(prop, FieldText):
            augmented.text = await get_field_extracted_text(field_id, field)

        elif isinstance(prop, FieldClassificationLabels):
            augmented.classification_labels = await classification_labels(field_id, field.resource)

        elif isinstance(prop, FieldEntities):
            augmented.entities = await field_entities(field_id, field)

        # link field props

        elif isinstance(prop, FieldValue):
            db_value = await field.get_value()
            if db_value is None:
                continue
            augmented.value = from_proto.field_link(db_value)

        else:  # pragma: no cover
            assert_never(prop)

    return augmented


@augmentor_observer.wrap({"type": "db_conversation_field"})
async def db_augment_conversation_field(
    field: Conversation,
    field_id: FieldId,
    select: list[ConversationProp],
) -> AugmentedConversationField:
    augmented = AugmentedConversationField(id=field.field_id)
    # map (page, index) -> augmented message. The key uniquely identifies and
    # orders messages
    messages: dict[tuple[int, int], AugmentedConversationMessage] = {}

    for prop in select:
        if isinstance(prop, FieldText):
            if isinstance(prop, ConversationText):
                selector = prop.selector
            else:
                # when asking for the conversation text without details, we
                # select the single message if a split is provided in the id,
                # or the full conversation otherwise
                if field_id.subfield_id is not None:
                    selector = MessageSelector()
                else:
                    selector = FullSelector()

            # gather the text from each message matching the selector
            extracted_text_pb = await cache.get_field_extracted_text(field)
            async for page, index, message in conversation_selector(field, field_id, selector):
                augmented_message = messages.setdefault(
                    (page, index), AugmentedConversationMessage(ident=message.ident)
                )
                if extracted_text_pb is not None and message.ident in extracted_text_pb.split_text:
                    augmented_message.text = extracted_text_pb.split_text[message.ident]
                else:
                    augmented_message.text = message.content.text

        elif isinstance(prop, FieldValue):
            db_value = await field.get_metadata()
            augmented.value = from_proto.field_conversation(db_value)

        elif isinstance(prop, FieldClassificationLabels):
            augmented.classification_labels = await classification_labels(field_id, field.resource)

        elif isinstance(prop, FieldEntities):
            augmented.entities = await field_entities(field_id, field)

        elif isinstance(prop, ConversationAttachments):
            # Each message on a conversation field can have attachments as
            # references to other fields in the same resource.
            #
            # Here, we iterate through all the messages matched by the selector
            # and collect all the attachment references
            async for page, index, message in conversation_selector(field, field_id, prop.selector):
                augmented_message = messages.setdefault(
                    (page, index), AugmentedConversationMessage(ident=message.ident)
                )
                augmented_message.attachments = []
                for ref in message.content.attachments_fields:
                    field_id = FieldId.from_pb(
                        field.uuid, ref.field_type, ref.field_id, ref.split or None
                    )
                    augmented_message.attachments.append(field_id)

        elif isinstance(prop, ConversationAnswerOrAfter):
            async for page, index, message in conversation_answer_or_after(field, field_id):
                augmented_message = messages.setdefault(
                    (page, index), AugmentedConversationMessage(ident=message.ident)
                )
                if not augmented_message.text:
                    augmented_message.text = message.content.text

        else:  # pragma: no cover
            assert_never(prop)

    if len(messages) > 0:
        augmented.messages = []
        for (_page, _index), m in sorted(messages.items()):
            augmented.messages.append(m)

    return augmented


@augmentor_observer.wrap({"type": "db_generic_field"})
async def db_augment_generic_field(
    field: Generic,
    field_id: FieldId,
    select: Sequence[FieldProp],
) -> AugmentedGenericField:
    augmented = AugmentedGenericField(id=field.field_id)

    for prop in select:
        if isinstance(prop, FieldText):
            augmented.text = await get_field_extracted_text(field_id, field)

        elif isinstance(prop, FieldClassificationLabels):
            augmented.classification_labels = await classification_labels(field_id, field.resource)

        elif isinstance(prop, FieldEntities):
            augmented.entities = await field_entities(field_id, field)

        # generic field props

        elif isinstance(prop, FieldValue):
            db_value = await field.get_value()
            augmented.value = db_value

        else:  # pragma: no cover
            assert_never(prop)

    return augmented


@augmentor_observer.wrap({"type": "field_text"})
async def get_field_extracted_text(id: FieldId, field: Field) -> str | None:
    extracted_text_pb = await cache.get_field_extracted_text(field)
    if extracted_text_pb is None:  # pragma: no cover
        return None

    if id.subfield_id:
        return extracted_text_pb.split_text[id.subfield_id]
    else:
        return extracted_text_pb.text


async def classification_labels(id: FieldId, resource: Resource) -> dict[str, set[str]] | None:
    basic = await get_basic(resource)
    if basic is None:
        return None

    labels: dict[str, set[str]] = {}
    for fc in basic.computedmetadata.field_classifications:
        if fc.field.field == id.key and fc.field.field_type == id.pb_type:
            for classification in fc.classifications:
                if classification.cancelled_by_user:  # pragma: no cover
                    continue
                labels.setdefault(classification.labelset, set()).add(classification.label)
    return labels


async def field_entities(id: FieldId, field: Field) -> dict[str, set[str]] | None:
    field_metadata = await field.get_field_metadata()
    if field_metadata is None:
        return None

    ners: dict[str, set[str]] = {}
    # Data Augmentation + Processor entities
    for (
        data_augmentation_task_id,
        entities_wrapper,
    ) in field_metadata.metadata.entities.items():
        for entity in entities_wrapper.entities:
            ners.setdefault(entity.label, set()).add(entity.text)
    # Legacy processor entities
    # TODO(decoupled-ask): Remove once the processor doesn't use this anymore
    # and remove the positions and ner fields from the message
    for token, family in field_metadata.metadata.ner.items():
        ners.setdefault(family, set()).add(token)

    return ners
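
# Both helpers above return a group -> values mapping, e.g. (illustrative)
# {"labelset": {"label_a", "label_b"}} for classification labels and
# {"PERSON": {"Alice", "Bob"}} for entities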


async def get_file_thumbnail_path(field: File, field_id: FieldId) -> str | None:
    thumbnail = await field.thumbnail()
    if thumbnail is None:
        return None

    # When ingesting file processed data, we move thumbnails to an owned
    # path. The thumbnail.key must then match this path so we can safely
    # return a path that can be used with the download API to get the
    # actual image
    _expected_prefix = STORAGE_FILE_EXTRACTED.format(
        kbid=field.kbid, uuid=field.uuid, field_type=field_id.type, field=field_id.key, key=""
    )
    assert thumbnail.key.startswith(_expected_prefix), (
        "we use a hardcoded path for file thumbnails and we assume it is this one"
    )
    thumbnail_path = thumbnail.key.removeprefix(_expected_prefix)

    return thumbnail_path
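
# e.g. (illustrative): if the expected extracted-files prefix renders to
# ".../f/myfile/", a thumbnail stored under ".../f/myfile/thumbnail" yields
# the relative path "thumbnail", which can then be fetched via the download
# API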


async def find_conversation_message(
    field: Conversation, ident: str
) -> tuple[int, int, resources_pb2.Message] | None:
    """Find a message in the conversation identified by `ident`."""
    conversation_metadata = await field.get_metadata()
    for page in range(1, conversation_metadata.pages + 1):
        conversation = await field.db_get_value(page)
        for idx, message in enumerate(conversation.messages):
            if message.ident == ident:
                return page, idx, message
    return None


async def iter_conversation_messages(
    field: Conversation,
    *,
    start_from: tuple[int, int] = (1, 0),  # (page, message)
) -> AsyncIterator[tuple[int, int, resources_pb2.Message]]:
    """Iterate through the conversation messages starting from a specific page
    and index.

    """
    start_page, start_index = start_from
    conversation_metadata = await field.get_metadata()
    for page in range(start_page, conversation_metadata.pages + 1):
        conversation = await field.db_get_value(page)
        for idx, message in enumerate(conversation.messages[start_index:]):
            yield (page, start_index + idx, message)
        # next iteration we want all messages
        start_index = 0
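
# e.g. (illustrative): with 2 pages of 200 messages each,
# iter_conversation_messages(field, start_from=(2, 5)) yields (2, 5), (2, 6)
# ... up to (2, 199), skipping page 1 entirely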


async def conversation_answer(
    field: Conversation,
    *,
    start_from: tuple[int, int] = (1, 0),  # (page, message)
) -> tuple[int, int, resources_pb2.Message] | None:
    """Find the next conversation message of type ANSWER starting from a
    specific page and index.

    """
    async for page, index, message in iter_conversation_messages(field, start_from=start_from):
        if message.type == resources_pb2.Message.MessageType.ANSWER:
            return page, index, message
    return None


async def conversation_messages_after(
    field: Conversation,
    *,
    start_from: tuple[int, int] = (1, 0),  # (page, index)
    limit: int | None = None,
) -> AsyncIterator[tuple[int, int, resources_pb2.Message]]:
    assert limit is None or limit > 0, "this function can't iterate backwards"
    async for page, index, message in iter_conversation_messages(field, start_from=start_from):
        yield page, index, message

        if limit is not None:
            limit -= 1
            if limit == 0:
                break


async def conversation_selector(
    field: Conversation,
    field_id: FieldId,
    selector: ConversationSelector,
) -> AsyncIterator[tuple[int, int, resources_pb2.Message]]:
    """Given a conversation, iterate through the messages matched by a
    selector.

    """
    split = field_id.subfield_id

    if isinstance(selector, MessageSelector):
        if selector.id is None and selector.index is None and split is None:
            return

        if selector.index is not None:
            metadata = await field.get_metadata()
            if metadata is None:
                # we can't know about pages/messages
                return

            if isinstance(selector.index, int):
                page = selector.index // metadata.size + 1
                index = selector.index % metadata.size
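                # e.g. with metadata.size == 200 and selector.index == 250,
                # this resolves to page 2 (pages are 1-based) and index 50
                # within that page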

            elif isinstance(selector.index, str):
                if selector.index == "first":
                    page, index = (1, 0)
                elif selector.index == "last":
                    page = metadata.pages
                    index = metadata.total % metadata.size - 1
                else:  # pragma: no cover
                    assert_never(selector.index)

            else:  # pragma: no cover
                assert_never(selector.index)

            found = None
            async for found in iter_conversation_messages(field, start_from=(page, index)):
                break

            if found is None:
                return

            page, index, message = found
            yield page, index, message

        else:
            # selector.id takes priority over the field id, as it is more specific
            if selector.id is not None:
                split = selector.id
            assert split is not None

            found = await find_conversation_message(field, split)
            if found is None:
                return

            page, index, message = found
            yield page, index, message

    elif isinstance(selector, PageSelector):
        if split is None:
            return
        found = await find_conversation_message(field, split)
        if found is None:
            return
        page, _, _ = found

        conversation_page = await field.db_get_value(page)
        for index, message in enumerate(conversation_page.messages):
            yield page, index, message

    elif isinstance(selector, NeighboursSelector):
        selector = cast(NeighboursSelector, selector)
        if split is None:
            return
        found = await find_conversation_message(field, split)
        if found is None:
            return
        page, index, message = found
        yield page, index, message

        start_from = (page, index + 1)
        async for page, index, message in conversation_messages_after(
            field, start_from=start_from, limit=selector.after
        ):
            yield page, index, message

    elif isinstance(selector, WindowSelector):
        if split is None:
            return
        # Find the position of the `split` message and get the window
        # surrounding it. If there are not enough preceding/following messages,
        # the window won't be centered
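        # e.g. with selector.size == 5, (size - 1) // 2 == 2 extra messages
        # are consumed after the match, so the deque ends up holding the
        # match, those 2, and up to 2 preceding messages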
        messages: Deque[tuple[int, int, resources_pb2.Message]] = deque(maxlen=selector.size)
        metadata = await field.get_metadata()
        pending = -1
        for page in range(1, metadata.pages + 1):
            conversation_page = await field.db_get_value(page)
            for index, message in enumerate(conversation_page.messages):
                messages.append((page, index, message))
                if pending > 0:
                    pending -= 1
                if message.ident == split:
                    pending = (selector.size - 1) // 2
                if pending == 0:
                    break
            if pending == 0:
                break

        for page, index, message in messages:
            yield page, index, message

    elif isinstance(selector, AnswerSelector):
        if split is None:
            return
        found = await find_conversation_message(field, split)
        if found is None:
            return
        page, index, message = found

        found = await conversation_answer(field, start_from=(page, index))
        if found is not None:
            page, index, answer = found
            yield page, index, answer

    elif isinstance(selector, FullSelector):
        async for page, index, message in iter_conversation_messages(field):
            yield page, index, message

    else:  # pragma: no cover
        assert_never(selector)


async def conversation_answer_or_after(
    field: Conversation, field_id: FieldId
) -> AsyncIterator[tuple[int, int, resources_pb2.Message]]:
    m: resources_pb2.Message | None = None
    # first search the message in the conversation
    async for page, index, m in conversation_selector(field, field_id, MessageSelector()):
        pass

    if m is None:
        return

    if m.type == resources_pb2.Message.MessageType.QUESTION:
        # try to find an answer for this question
        found = await conversation_answer(field, start_from=(page, index + 1))
        if found is None:
            return
        else:
            page, index, answer = found
            yield page, index, answer

    else:
        # add a bunch of messages after this for more context
        async for page, index, message in conversation_messages_after(
            field, start_from=(page, index + 1), limit=CONVERSATION_MESSAGE_CONTEXT_EXPANSION
        ):
            yield page, index, message
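

# Behaviour sketch (illustrative): given a field id carrying a message split,
# if the matched message is a QUESTION, only the following ANSWER message is
# yielded; for any other message type, up to
# CONVERSATION_MESSAGE_CONTEXT_EXPANSION (15) subsequent messages are yielded
# as extra context.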