nucliadb 6.9.1.post5192__py3-none-any.whl → 6.10.0.post5705__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. migrations/0023_backfill_pg_catalog.py +2 -2
  2. migrations/0029_backfill_field_status.py +3 -4
  3. migrations/0032_remove_old_relations.py +2 -3
  4. migrations/0038_backfill_catalog_field_labels.py +2 -2
  5. migrations/0039_backfill_converation_splits_metadata.py +2 -2
  6. migrations/0041_reindex_conversations.py +137 -0
  7. migrations/pg/0010_shards_index.py +34 -0
  8. nucliadb/search/api/v1/resource/utils.py → migrations/pg/0011_catalog_statistics.py +5 -6
  9. migrations/pg/0012_catalog_statistics_undo.py +26 -0
  10. nucliadb/backups/create.py +2 -15
  11. nucliadb/backups/restore.py +4 -15
  12. nucliadb/backups/tasks.py +4 -1
  13. nucliadb/common/back_pressure/cache.py +2 -3
  14. nucliadb/common/back_pressure/materializer.py +7 -13
  15. nucliadb/common/back_pressure/settings.py +6 -6
  16. nucliadb/common/back_pressure/utils.py +1 -0
  17. nucliadb/common/cache.py +9 -9
  18. nucliadb/common/catalog/interface.py +12 -12
  19. nucliadb/common/catalog/pg.py +41 -29
  20. nucliadb/common/catalog/utils.py +3 -3
  21. nucliadb/common/cluster/manager.py +5 -4
  22. nucliadb/common/cluster/rebalance.py +483 -114
  23. nucliadb/common/cluster/rollover.py +25 -9
  24. nucliadb/common/cluster/settings.py +3 -8
  25. nucliadb/common/cluster/utils.py +34 -8
  26. nucliadb/common/context/__init__.py +7 -8
  27. nucliadb/common/context/fastapi.py +1 -2
  28. nucliadb/common/datamanagers/__init__.py +2 -4
  29. nucliadb/common/datamanagers/atomic.py +4 -2
  30. nucliadb/common/datamanagers/cluster.py +1 -2
  31. nucliadb/common/datamanagers/fields.py +3 -4
  32. nucliadb/common/datamanagers/kb.py +6 -6
  33. nucliadb/common/datamanagers/labels.py +2 -3
  34. nucliadb/common/datamanagers/resources.py +10 -33
  35. nucliadb/common/datamanagers/rollover.py +5 -7
  36. nucliadb/common/datamanagers/search_configurations.py +1 -2
  37. nucliadb/common/datamanagers/synonyms.py +1 -2
  38. nucliadb/common/datamanagers/utils.py +4 -4
  39. nucliadb/common/datamanagers/vectorsets.py +4 -4
  40. nucliadb/common/external_index_providers/base.py +32 -5
  41. nucliadb/common/external_index_providers/manager.py +4 -5
  42. nucliadb/common/filter_expression.py +128 -40
  43. nucliadb/common/http_clients/processing.py +12 -23
  44. nucliadb/common/ids.py +6 -4
  45. nucliadb/common/locking.py +1 -2
  46. nucliadb/common/maindb/driver.py +9 -8
  47. nucliadb/common/maindb/local.py +5 -5
  48. nucliadb/common/maindb/pg.py +9 -8
  49. nucliadb/common/nidx.py +3 -4
  50. nucliadb/export_import/datamanager.py +4 -3
  51. nucliadb/export_import/exporter.py +11 -19
  52. nucliadb/export_import/importer.py +13 -6
  53. nucliadb/export_import/tasks.py +2 -0
  54. nucliadb/export_import/utils.py +6 -18
  55. nucliadb/health.py +2 -2
  56. nucliadb/ingest/app.py +8 -8
  57. nucliadb/ingest/consumer/consumer.py +8 -10
  58. nucliadb/ingest/consumer/pull.py +3 -8
  59. nucliadb/ingest/consumer/service.py +3 -3
  60. nucliadb/ingest/consumer/utils.py +1 -1
  61. nucliadb/ingest/fields/base.py +28 -49
  62. nucliadb/ingest/fields/conversation.py +12 -12
  63. nucliadb/ingest/fields/exceptions.py +1 -2
  64. nucliadb/ingest/fields/file.py +22 -8
  65. nucliadb/ingest/fields/link.py +7 -7
  66. nucliadb/ingest/fields/text.py +2 -3
  67. nucliadb/ingest/orm/brain_v2.py +78 -64
  68. nucliadb/ingest/orm/broker_message.py +2 -4
  69. nucliadb/ingest/orm/entities.py +10 -209
  70. nucliadb/ingest/orm/index_message.py +4 -4
  71. nucliadb/ingest/orm/knowledgebox.py +18 -27
  72. nucliadb/ingest/orm/processor/auditing.py +1 -3
  73. nucliadb/ingest/orm/processor/data_augmentation.py +1 -2
  74. nucliadb/ingest/orm/processor/processor.py +27 -27
  75. nucliadb/ingest/orm/processor/sequence_manager.py +1 -2
  76. nucliadb/ingest/orm/resource.py +72 -70
  77. nucliadb/ingest/orm/utils.py +1 -1
  78. nucliadb/ingest/processing.py +17 -17
  79. nucliadb/ingest/serialize.py +202 -145
  80. nucliadb/ingest/service/writer.py +3 -109
  81. nucliadb/ingest/settings.py +3 -4
  82. nucliadb/ingest/utils.py +1 -2
  83. nucliadb/learning_proxy.py +11 -11
  84. nucliadb/metrics_exporter.py +5 -4
  85. nucliadb/middleware/__init__.py +82 -1
  86. nucliadb/migrator/datamanager.py +3 -4
  87. nucliadb/migrator/migrator.py +1 -2
  88. nucliadb/migrator/models.py +1 -2
  89. nucliadb/migrator/settings.py +1 -2
  90. nucliadb/models/internal/augment.py +614 -0
  91. nucliadb/models/internal/processing.py +19 -19
  92. nucliadb/openapi.py +2 -2
  93. nucliadb/purge/__init__.py +3 -8
  94. nucliadb/purge/orphan_shards.py +1 -2
  95. nucliadb/reader/__init__.py +5 -0
  96. nucliadb/reader/api/models.py +6 -13
  97. nucliadb/reader/api/v1/download.py +59 -38
  98. nucliadb/reader/api/v1/export_import.py +4 -4
  99. nucliadb/reader/api/v1/learning_config.py +24 -4
  100. nucliadb/reader/api/v1/resource.py +61 -9
  101. nucliadb/reader/api/v1/services.py +18 -14
  102. nucliadb/reader/app.py +3 -1
  103. nucliadb/reader/reader/notifications.py +1 -2
  104. nucliadb/search/api/v1/__init__.py +2 -0
  105. nucliadb/search/api/v1/ask.py +3 -4
  106. nucliadb/search/api/v1/augment.py +585 -0
  107. nucliadb/search/api/v1/catalog.py +11 -15
  108. nucliadb/search/api/v1/find.py +16 -22
  109. nucliadb/search/api/v1/hydrate.py +25 -25
  110. nucliadb/search/api/v1/knowledgebox.py +1 -2
  111. nucliadb/search/api/v1/predict_proxy.py +1 -2
  112. nucliadb/search/api/v1/resource/ask.py +7 -7
  113. nucliadb/search/api/v1/resource/ingestion_agents.py +5 -6
  114. nucliadb/search/api/v1/resource/search.py +9 -11
  115. nucliadb/search/api/v1/retrieve.py +130 -0
  116. nucliadb/search/api/v1/search.py +28 -32
  117. nucliadb/search/api/v1/suggest.py +11 -14
  118. nucliadb/search/api/v1/summarize.py +1 -2
  119. nucliadb/search/api/v1/utils.py +2 -2
  120. nucliadb/search/app.py +3 -2
  121. nucliadb/search/augmentor/__init__.py +21 -0
  122. nucliadb/search/augmentor/augmentor.py +232 -0
  123. nucliadb/search/augmentor/fields.py +704 -0
  124. nucliadb/search/augmentor/metrics.py +24 -0
  125. nucliadb/search/augmentor/paragraphs.py +334 -0
  126. nucliadb/search/augmentor/resources.py +238 -0
  127. nucliadb/search/augmentor/utils.py +33 -0
  128. nucliadb/search/lifecycle.py +3 -1
  129. nucliadb/search/predict.py +24 -17
  130. nucliadb/search/predict_models.py +8 -9
  131. nucliadb/search/requesters/utils.py +11 -10
  132. nucliadb/search/search/cache.py +19 -23
  133. nucliadb/search/search/chat/ask.py +88 -59
  134. nucliadb/search/search/chat/exceptions.py +3 -5
  135. nucliadb/search/search/chat/fetcher.py +201 -0
  136. nucliadb/search/search/chat/images.py +6 -4
  137. nucliadb/search/search/chat/old_prompt.py +1375 -0
  138. nucliadb/search/search/chat/parser.py +510 -0
  139. nucliadb/search/search/chat/prompt.py +563 -615
  140. nucliadb/search/search/chat/query.py +449 -36
  141. nucliadb/search/search/chat/rpc.py +85 -0
  142. nucliadb/search/search/fetch.py +3 -4
  143. nucliadb/search/search/filters.py +8 -11
  144. nucliadb/search/search/find.py +33 -31
  145. nucliadb/search/search/find_merge.py +124 -331
  146. nucliadb/search/search/graph_strategy.py +14 -12
  147. nucliadb/search/search/hydrator/__init__.py +3 -152
  148. nucliadb/search/search/hydrator/fields.py +92 -50
  149. nucliadb/search/search/hydrator/images.py +7 -7
  150. nucliadb/search/search/hydrator/paragraphs.py +42 -26
  151. nucliadb/search/search/hydrator/resources.py +20 -16
  152. nucliadb/search/search/ingestion_agents.py +5 -5
  153. nucliadb/search/search/merge.py +90 -94
  154. nucliadb/search/search/metrics.py +10 -9
  155. nucliadb/search/search/paragraphs.py +7 -9
  156. nucliadb/search/search/predict_proxy.py +13 -9
  157. nucliadb/search/search/query.py +14 -86
  158. nucliadb/search/search/query_parser/fetcher.py +51 -82
  159. nucliadb/search/search/query_parser/models.py +19 -20
  160. nucliadb/search/search/query_parser/old_filters.py +20 -19
  161. nucliadb/search/search/query_parser/parsers/ask.py +4 -5
  162. nucliadb/search/search/query_parser/parsers/catalog.py +5 -6
  163. nucliadb/search/search/query_parser/parsers/common.py +5 -6
  164. nucliadb/search/search/query_parser/parsers/find.py +6 -26
  165. nucliadb/search/search/query_parser/parsers/graph.py +13 -23
  166. nucliadb/search/search/query_parser/parsers/retrieve.py +207 -0
  167. nucliadb/search/search/query_parser/parsers/search.py +15 -53
  168. nucliadb/search/search/query_parser/parsers/unit_retrieval.py +8 -29
  169. nucliadb/search/search/rank_fusion.py +18 -13
  170. nucliadb/search/search/rerankers.py +5 -6
  171. nucliadb/search/search/retrieval.py +300 -0
  172. nucliadb/search/search/summarize.py +5 -6
  173. nucliadb/search/search/utils.py +3 -4
  174. nucliadb/search/settings.py +1 -2
  175. nucliadb/standalone/api_router.py +1 -1
  176. nucliadb/standalone/app.py +4 -3
  177. nucliadb/standalone/auth.py +5 -6
  178. nucliadb/standalone/lifecycle.py +2 -2
  179. nucliadb/standalone/run.py +2 -4
  180. nucliadb/standalone/settings.py +5 -6
  181. nucliadb/standalone/versions.py +3 -4
  182. nucliadb/tasks/consumer.py +13 -8
  183. nucliadb/tasks/models.py +2 -1
  184. nucliadb/tasks/producer.py +3 -3
  185. nucliadb/tasks/retries.py +8 -7
  186. nucliadb/train/api/utils.py +1 -3
  187. nucliadb/train/api/v1/shards.py +1 -2
  188. nucliadb/train/api/v1/trainset.py +1 -2
  189. nucliadb/train/app.py +1 -1
  190. nucliadb/train/generator.py +4 -4
  191. nucliadb/train/generators/field_classifier.py +2 -2
  192. nucliadb/train/generators/field_streaming.py +6 -6
  193. nucliadb/train/generators/image_classifier.py +2 -2
  194. nucliadb/train/generators/paragraph_classifier.py +2 -2
  195. nucliadb/train/generators/paragraph_streaming.py +2 -2
  196. nucliadb/train/generators/question_answer_streaming.py +2 -2
  197. nucliadb/train/generators/sentence_classifier.py +2 -2
  198. nucliadb/train/generators/token_classifier.py +3 -2
  199. nucliadb/train/generators/utils.py +6 -5
  200. nucliadb/train/nodes.py +3 -3
  201. nucliadb/train/resource.py +6 -8
  202. nucliadb/train/settings.py +3 -4
  203. nucliadb/train/types.py +11 -11
  204. nucliadb/train/upload.py +3 -2
  205. nucliadb/train/uploader.py +1 -2
  206. nucliadb/train/utils.py +1 -2
  207. nucliadb/writer/api/v1/export_import.py +4 -1
  208. nucliadb/writer/api/v1/field.py +7 -11
  209. nucliadb/writer/api/v1/knowledgebox.py +3 -4
  210. nucliadb/writer/api/v1/resource.py +9 -20
  211. nucliadb/writer/api/v1/services.py +10 -132
  212. nucliadb/writer/api/v1/upload.py +73 -72
  213. nucliadb/writer/app.py +8 -2
  214. nucliadb/writer/resource/basic.py +12 -15
  215. nucliadb/writer/resource/field.py +7 -5
  216. nucliadb/writer/resource/origin.py +7 -0
  217. nucliadb/writer/settings.py +2 -3
  218. nucliadb/writer/tus/__init__.py +2 -3
  219. nucliadb/writer/tus/azure.py +1 -3
  220. nucliadb/writer/tus/dm.py +3 -3
  221. nucliadb/writer/tus/exceptions.py +3 -4
  222. nucliadb/writer/tus/gcs.py +5 -6
  223. nucliadb/writer/tus/s3.py +2 -3
  224. nucliadb/writer/tus/storage.py +3 -3
  225. {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/METADATA +9 -10
  226. nucliadb-6.10.0.post5705.dist-info/RECORD +410 -0
  227. nucliadb/common/datamanagers/entities.py +0 -139
  228. nucliadb-6.9.1.post5192.dist-info/RECORD +0 -392
  229. {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/WHEEL +0 -0
  230. {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/entry_points.txt +0 -0
  231. {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/top_level.txt +0 -0
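--- /dev/null
+++ nucliadb/search/augmentor/fields.py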
@@ -0,0 +1,704 @@
+# Copyright (C) 2021 Bosutech XXI S.L.
+#
+# nucliadb is offered under the AGPL v3.0 and as commercial software.
+# For commercial licensing, contact us at info@nuclia.com.
+#
+# AGPL:
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+#
+import asyncio
+from collections import deque
+from collections.abc import AsyncIterator, Sequence
+from typing import Deque, cast
+
+from typing_extensions import assert_never
+
+from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB, FieldId
+from nucliadb.common.models_utils import from_proto
+from nucliadb.ingest.fields.base import Field
+from nucliadb.ingest.fields.conversation import Conversation
+from nucliadb.ingest.fields.file import File
+from nucliadb.ingest.fields.generic import Generic
+from nucliadb.ingest.fields.link import Link
+from nucliadb.ingest.fields.text import Text
+from nucliadb.ingest.orm.resource import Resource
+from nucliadb.models.internal.augment import (
+    AnswerSelector,
+    AugmentedConversationField,
+    AugmentedConversationMessage,
+    AugmentedField,
+    AugmentedFileField,
+    AugmentedGenericField,
+    AugmentedLinkField,
+    AugmentedTextField,
+    ConversationAnswerOrAfter,
+    ConversationAttachments,
+    ConversationProp,
+    ConversationSelector,
+    ConversationText,
+    FieldClassificationLabels,
+    FieldEntities,
+    FieldProp,
+    FieldText,
+    FieldValue,
+    FileProp,
+    FileThumbnail,
+    FullSelector,
+    MessageSelector,
+    NeighboursSelector,
+    PageSelector,
+    WindowSelector,
+)
+from nucliadb.search.augmentor.metrics import augmentor_observer
+from nucliadb.search.augmentor.resources import get_basic
+from nucliadb.search.augmentor.utils import limited_concurrency
+from nucliadb.search.search import cache
+from nucliadb_models.common import FieldTypeName
+from nucliadb_protos import resources_pb2
+from nucliadb_utils.storages.storage import STORAGE_FILE_EXTRACTED
+
+# Number of messages to pull after a match in a message.
+# The hope is that this will be enough to capture the answer to the question.
+CONVERSATION_MESSAGE_CONTEXT_EXPANSION = 15
+
+
+async def augment_fields(
+    kbid: str,
+    given: list[FieldId],
+    select: list[FieldProp | ConversationProp],
+    *,
+    concurrency_control: asyncio.Semaphore | None = None,
+) -> dict[FieldId, AugmentedField | None]:
+    """Augment a list of fields according to the selected properties"""
+
+    ops = []
+    for field_id in given:
+        task = asyncio.create_task(
+            limited_concurrency(
+                augment_field(kbid, field_id, select),
+                max_ops=concurrency_control,
+            )
+        )
+        ops.append(task)
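+    # the augmentations run concurrently via asyncio.gather; limited_concurrency
+    # presumably gates each coroutine on the optional semaphore, bounding how
+    # many field fetches run at once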
+    results: list[AugmentedField | None] = await asyncio.gather(*ops)
+
+    augmented = {}
+    for field_id, augmentation in zip(given, results):
+        augmented[field_id] = augmentation
+
+    return augmented
+
+
+@augmentor_observer.wrap({"type": "field"})
+async def augment_field(
+    kbid: str,
+    field_id: FieldId,
+    select: Sequence[FieldProp | ConversationProp],
+) -> AugmentedField | None:
+    rid = field_id.rid
+    resource = await cache.get_resource(kbid, rid)
+    if resource is None:
+        # skip resources that aren't in the DB
+        return None
+
+    field_type_pb = FIELD_TYPE_STR_TO_PB[field_id.type]
+    # we must check if the field exists, or get_field will return an empty
+    # field (behaviour intended for ingestion) that we don't want
+    if not (await resource.field_exists(field_type_pb, field_id.key)):
+        # skip fields that aren't in the DB
+        return None
+    field = await resource.get_field(field_id.key, field_id.pb_type)
+
+    return await db_augment_field(field, field_id, select)
+
+
+async def db_augment_field(
+    field: Field,
+    field_id: FieldId,
+    select: Sequence[FieldProp | FileProp | ConversationProp],
+) -> AugmentedField:
+    select = dedup_field_select(select)
+
+    field_type = field_id.type
+
+    # Note we cast `select` to the specific Union type required by each
+    # db_augment_* function. This is safe even if there are props that are
+    # not for a specific field, as they will be ignored
+
+    if field_type == FieldTypeName.TEXT.abbreviation():
+        field = cast(Text, field)
+        select = cast(list[FieldProp], select)
+        return await db_augment_text_field(field, field_id, select)
+
+    elif field_type == FieldTypeName.FILE.abbreviation():
+        field = cast(File, field)
+        select = cast(list[FileProp], select)
+        return await db_augment_file_field(field, field_id, select)
+
+    elif field_type == FieldTypeName.LINK.abbreviation():
+        field = cast(Link, field)
+        select = cast(list[FieldProp], select)
+        return await db_augment_link_field(field, field_id, select)
+
+    elif field_type == FieldTypeName.CONVERSATION.abbreviation():
+        field = cast(Conversation, field)
+        select = cast(list[ConversationProp], select)
+        return await db_augment_conversation_field(field, field_id, select)
+
+    elif field_type == FieldTypeName.GENERIC.abbreviation():
+        field = cast(Generic, field)
+        select = cast(list[FieldProp], select)
+        return await db_augment_generic_field(field, field_id, select)
+
+    else:  # pragma: no cover
+        assert False, f"unknown field type: {field_type}"
+
+
+def dedup_field_select(
+    select: Sequence[FieldProp | FileProp | ConversationProp],
+) -> Sequence[FieldProp | FileProp | ConversationProp]:
+    """Merge any duplicated property, taking the broadest augmentation possible."""
+    merged = {}
+
+    # TODO(decoupled-ask): deduplicate conversation props.
+    #
+    # Note that only conversation properties can be deduplicated (none of the
+    # others have any field). However, deduplicating the selector is not
+    # possible in many cases, so we do nothing
+    unmergeable = []
+
+    for prop in select:
+        if prop.prop not in merged:
+            merged[prop.prop] = prop
+
+        else:
+            if isinstance(prop, ConversationText) or isinstance(prop, ConversationAttachments):
+                unmergeable.append(prop)
+            elif (
+                isinstance(prop, FieldText)
+                or isinstance(prop, FieldValue)
+                or isinstance(prop, FieldClassificationLabels)
+                or isinstance(prop, FieldEntities)
+                or isinstance(prop, FileThumbnail)
+                or isinstance(prop, ConversationAnswerOrAfter)
+            ):
+                # properties without parameters
+                pass
+            else:  # pragma: no cover
+                assert_never(prop)
+
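+    # Illustrative example: [FieldText(), FieldText(), ConversationText(a), ConversationText(b)]
+    # returns [FieldText(), ConversationText(a), ConversationText(b)]: duplicates of
+    # parameterless props collapse into one, while selector-bearing conversation
+    # props are always kept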
+    return [*merged.values(), *unmergeable]
+
+
+@augmentor_observer.wrap({"type": "db_text_field"})
+async def db_augment_text_field(
+    field: Text,
+    field_id: FieldId,
+    select: Sequence[FieldProp],
+) -> AugmentedTextField:
+    augmented = AugmentedTextField(id=field.field_id)
+
+    for prop in select:
+        if isinstance(prop, FieldText):
+            augmented.text = await get_field_extracted_text(field_id, field)
+
+        elif isinstance(prop, FieldClassificationLabels):
+            augmented.classification_labels = await classification_labels(field_id, field.resource)
+
+        elif isinstance(prop, FieldEntities):
+            augmented.entities = await field_entities(field_id, field)
+
+        # text field props
+
+        elif isinstance(prop, FieldValue):
+            db_value = await field.get_value()
+            if db_value is None:
+                continue
+            augmented.value = from_proto.field_text(db_value)
+
+        else:  # pragma: no cover
+            assert_never(prop)
+
+    return augmented
+
+
+@augmentor_observer.wrap({"type": "db_file_field"})
+async def db_augment_file_field(
+    field: File,
+    field_id: FieldId,
+    select: Sequence[FileProp],
+) -> AugmentedFileField:
+    augmented = AugmentedFileField(id=field.field_id)
+
+    for prop in select:
+        if isinstance(prop, FieldText):
+            augmented.text = await get_field_extracted_text(field_id, field)
+
+        elif isinstance(prop, FieldClassificationLabels):
+            augmented.classification_labels = await classification_labels(field_id, field.resource)
+
+        elif isinstance(prop, FieldEntities):
+            augmented.entities = await field_entities(field_id, field)
+
+        # file field props
+
+        elif isinstance(prop, FieldValue):
+            db_value = await field.get_value()
+            if db_value is None:
+                continue
+            augmented.value = from_proto.field_file(db_value)
+
+        elif isinstance(prop, FileThumbnail):
+            augmented.thumbnail_path = await get_file_thumbnail_path(field, field_id)
+
+        else:  # pragma: no cover
+            assert_never(prop)
+
+    return augmented
+
+
+@augmentor_observer.wrap({"type": "db_link_field"})
+async def db_augment_link_field(
+    field: Link,
+    field_id: FieldId,
+    select: Sequence[FieldProp],
+) -> AugmentedLinkField:
+    augmented = AugmentedLinkField(id=field.field_id)
+
+    for prop in select:
+        if isinstance(prop, FieldText):
+            augmented.text = await get_field_extracted_text(field_id, field)
+
+        elif isinstance(prop, FieldClassificationLabels):
+            augmented.classification_labels = await classification_labels(field_id, field.resource)
+
+        elif isinstance(prop, FieldEntities):
+            augmented.entities = await field_entities(field_id, field)
+
+        # link field props
+
+        elif isinstance(prop, FieldValue):
+            db_value = await field.get_value()
+            if db_value is None:
+                continue
+            augmented.value = from_proto.field_link(db_value)
+
+        else:  # pragma: no cover
+            assert_never(prop)
+
+    return augmented
+
+
+@augmentor_observer.wrap({"type": "db_conversation_field"})
+async def db_augment_conversation_field(
+    field: Conversation,
+    field_id: FieldId,
+    select: list[ConversationProp],
+) -> AugmentedConversationField:
+    augmented = AugmentedConversationField(id=field.field_id)
+    # map (page, index) -> augmented message. The key uniquely identifies and
+    # orders messages
+    messages: dict[tuple[int, int], AugmentedConversationMessage] = {}
+
+    for prop in select:
+        if isinstance(prop, FieldText):
+            if isinstance(prop, ConversationText):
+                selector = prop.selector
+            else:
+                # when asking for the conversation text without details, we
+                # select the message if a split is provided in the id, or the
+                # full conversation otherwise
+                if field_id.subfield_id is not None:
+                    selector = MessageSelector()
+                else:
+                    selector = FullSelector()
+
+            # gather the text from each message matching the selector
+            extracted_text_pb = await cache.get_field_extracted_text(field)
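+            # when available, the processed extracted text stores each
+            # message's text in split_text, keyed by message ident; otherwise
+            # we fall back to the raw message content below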
+            async for page, index, message in conversation_selector(field, field_id, selector):
+                augmented_message = messages.setdefault(
+                    (page, index), AugmentedConversationMessage(ident=message.ident)
+                )
+                if extracted_text_pb is not None and message.ident in extracted_text_pb.split_text:
+                    augmented_message.text = extracted_text_pb.split_text[message.ident]
+                else:
+                    augmented_message.text = message.content.text
+
+        elif isinstance(prop, FieldValue):
+            db_value = await field.get_metadata()
+            augmented.value = from_proto.field_conversation(db_value)
+
+        elif isinstance(prop, FieldClassificationLabels):
+            augmented.classification_labels = await classification_labels(field_id, field.resource)
+
+        elif isinstance(prop, FieldEntities):
+            augmented.entities = await field_entities(field_id, field)
+
+        elif isinstance(prop, ConversationAttachments):
+            # Each message on a conversation field can have attachments as
+            # references to other fields in the same resource.
+            #
+            # Here, we iterate through all the messages matched by the selector
+            # and collect all the attachment references
+            async for page, index, message in conversation_selector(field, field_id, prop.selector):
+                augmented_message = messages.setdefault(
+                    (page, index), AugmentedConversationMessage(ident=message.ident)
+                )
+                augmented_message.attachments = []
+                for ref in message.content.attachments_fields:
+                    attachment_id = FieldId.from_pb(
+                        field.uuid, ref.field_type, ref.field_id, ref.split or None
+                    )
+                    augmented_message.attachments.append(attachment_id)
+
+        elif isinstance(prop, ConversationAnswerOrAfter):
+            async for page, index, message in conversation_answer_or_after(field, field_id):
+                augmented_message = messages.setdefault(
+                    (page, index), AugmentedConversationMessage(ident=message.ident)
+                )
+                if not augmented_message.text:
+                    augmented_message.text = message.content.text
+
+        else:  # pragma: no cover
+            assert_never(prop)
+
+    if len(messages) > 0:
+        augmented.messages = []
+        for (_page, _index), m in sorted(messages.items()):
+            augmented.messages.append(m)
+
+    return augmented
+
+
+@augmentor_observer.wrap({"type": "db_generic_field"})
+async def db_augment_generic_field(
+    field: Generic,
+    field_id: FieldId,
+    select: Sequence[FieldProp],
+) -> AugmentedGenericField:
+    augmented = AugmentedGenericField(id=field.field_id)
+
+    for prop in select:
+        if isinstance(prop, FieldText):
+            augmented.text = await get_field_extracted_text(field_id, field)
+
+        elif isinstance(prop, FieldClassificationLabels):
+            augmented.classification_labels = await classification_labels(field_id, field.resource)
+
+        elif isinstance(prop, FieldEntities):
+            augmented.entities = await field_entities(field_id, field)
+
+        # generic field props
+
+        elif isinstance(prop, FieldValue):
+            db_value = await field.get_value()
+            augmented.value = db_value
+
+        else:  # pragma: no cover
+            assert_never(prop)
+
+    return augmented
+
+
+@augmentor_observer.wrap({"type": "field_text"})
+async def get_field_extracted_text(id: FieldId, field: Field) -> str | None:
+    extracted_text_pb = await cache.get_field_extracted_text(field)
+    if extracted_text_pb is None:  # pragma: no cover
+        return None
+
+    if id.subfield_id:
+        return extracted_text_pb.split_text[id.subfield_id]
+    else:
+        return extracted_text_pb.text
+
+
+async def classification_labels(id: FieldId, resource: Resource) -> dict[str, set[str]] | None:
+    basic = await get_basic(resource)
+    if basic is None:
+        return None
+
+    labels: dict[str, set[str]] = {}
+    for fc in basic.computedmetadata.field_classifications:
+        if fc.field.field == id.key and fc.field.field_type == id.pb_type:
+            for classification in fc.classifications:
+                if classification.cancelled_by_user:  # pragma: no cover
+                    continue
+                labels.setdefault(classification.labelset, set()).add(classification.label)
+    return labels
+
+
+async def field_entities(id: FieldId, field: Field) -> dict[str, set[str]] | None:
+    field_metadata = await field.get_field_metadata()
+    if field_metadata is None:
+        return None
+
+    ners: dict[str, set[str]] = {}
+    # Data Augmentation + Processor entities
+    for (
+        data_augmentation_task_id,
+        entities_wrapper,
+    ) in field_metadata.metadata.entities.items():
+        for entity in entities_wrapper.entities:
+            ners.setdefault(entity.label, set()).add(entity.text)
+    # Legacy processor entities
+    # TODO(decoupled-ask): Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
+    for token, family in field_metadata.metadata.ner.items():
+        ners.setdefault(family, set()).add(token)
+
+    return ners
+
+
+async def get_file_thumbnail_path(field: File, field_id: FieldId) -> str | None:
+    thumbnail = await field.thumbnail()
+    if thumbnail is None:
+        return None
+
+    # When ingesting file processed data, we move thumbnails to an owned
+    # path. The thumbnail.key must then match this path so we can safely
+    # return a path that can be used with the download API to get the
+    # actual image
+    _expected_prefix = STORAGE_FILE_EXTRACTED.format(
+        kbid=field.kbid, uuid=field.uuid, field_type=field_id.type, field=field_id.key, key=""
+    )
+    assert thumbnail.key.startswith(_expected_prefix), (
+        "we use a hardcoded path for file thumbnails and we assume the key matches it"
+    )
+    thumbnail_path = thumbnail.key.removeprefix(_expected_prefix)
+
+    return thumbnail_path
+
+
+async def find_conversation_message(
+    field: Conversation, ident: str
+) -> tuple[int, int, resources_pb2.Message] | None:
+    """Find a message in the conversation identified by `ident`."""
+    conversation_metadata = await field.get_metadata()
+    for page in range(1, conversation_metadata.pages + 1):
+        conversation = await field.db_get_value(page)
+        for idx, message in enumerate(conversation.messages):
+            if message.ident == ident:
+                return page, idx, message
+    return None
+
+
+async def iter_conversation_messages(
+    field: Conversation,
+    *,
+    start_from: tuple[int, int] = (1, 0),  # (page, message)
+) -> AsyncIterator[tuple[int, int, resources_pb2.Message]]:
+    """Iterate through the conversation messages starting from a specific page
+    and index.
+
+    """
+    start_page, start_index = start_from
+    conversation_metadata = await field.get_metadata()
+    for page in range(start_page, conversation_metadata.pages + 1):
+        conversation = await field.db_get_value(page)
+        for idx, message in enumerate(conversation.messages[start_index:]):
+            yield (page, start_index + idx, message)
+        # next iteration we want all messages
+        start_index = 0
+
+
+async def conversation_answer(
+    field: Conversation,
+    *,
+    start_from: tuple[int, int] = (1, 0),  # (page, message)
+) -> tuple[int, int, resources_pb2.Message] | None:
+    """Find the next conversation message of type ANSWER starting from a
+    specific page and index.

+    """
+    async for page, index, message in iter_conversation_messages(field, start_from=start_from):
+        if message.type == resources_pb2.Message.MessageType.ANSWER:
+            return page, index, message
+    return None
+
+
+async def conversation_messages_after(
+    field: Conversation,
+    *,
+    start_from: tuple[int, int] = (1, 0),  # (page, index)
+    limit: int | None = None,
+) -> AsyncIterator[tuple[int, int, resources_pb2.Message]]:
+    assert limit is None or limit > 0, "this function can't iterate backwards"
+    async for page, index, message in iter_conversation_messages(field, start_from=start_from):
+        yield page, index, message
+
+        if limit is not None:
+            limit -= 1
+            if limit == 0:
+                break
+
+
+async def conversation_selector(
+    field: Conversation,
+    field_id: FieldId,
+    selector: ConversationSelector,
+) -> AsyncIterator[tuple[int, int, resources_pb2.Message]]:
+    """Given a conversation, iterate through the messages matched by a
+    selector.
+
+    """
+    split = field_id.subfield_id
+
+    if isinstance(selector, MessageSelector):
+        if selector.id is None and selector.index is None and split is None:
+            return
+
+        if selector.index is not None:
+            metadata = await field.get_metadata()
+            if metadata is None:
+                # we can't know about pages/messages
+                return
+
+            if isinstance(selector.index, int):
+                page = selector.index // metadata.size + 1
+                index = selector.index % metadata.size
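+                # e.g. with metadata.size == 50, a global index of 120 resolves
+                # to page 3 (120 // 50 + 1) and in-page index 20 (120 % 50)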
+
+            elif isinstance(selector.index, str):
+                if selector.index == "first":
+                    page, index = (1, 0)
+                elif selector.index == "last":
+                    page = metadata.pages
+                    index = metadata.total % metadata.size - 1
+                else:  # pragma: no cover
+                    assert_never(selector.index)
+
+            else:  # pragma: no cover
+                assert_never(selector.index)
+
+            found = None
+            async for found in iter_conversation_messages(field, start_from=(page, index)):
+                break
+
+            if found is None:
+                return
+
+            page, index, message = found
+            yield page, index, message
+
+        else:
+            # selector.id takes priority over the field id, as it is more specific
+            if selector.id is not None:
+                split = selector.id
+            assert split is not None
+
+            found = await find_conversation_message(field, split)
+            if found is None:
+                return
+
+            page, index, message = found
+            yield page, index, message
+
+    elif isinstance(selector, PageSelector):
+        if split is None:
+            return
+        found = await find_conversation_message(field, split)
+        if found is None:
+            return
+        page, _, _ = found
+
+        conversation_page = await field.db_get_value(page)
+        for index, message in enumerate(conversation_page.messages):
+            yield page, index, message
+
+    elif isinstance(selector, NeighboursSelector):
+        selector = cast(NeighboursSelector, selector)
+        if split is None:
+            return
+        found = await find_conversation_message(field, split)
+        if found is None:
+            return
+        page, index, message = found
+        yield page, index, message
+
+        start_from = (page, index + 1)
+        async for page, index, message in conversation_messages_after(
+            field, start_from=start_from, limit=selector.after
+        ):
+            yield page, index, message
+
+    elif isinstance(selector, WindowSelector):
+        if split is None:
+            return
+        # Find the position of the `split` message and get the window
+        # surrounding it. If there are not enough preceding/following messages,
+        # the window won't be centered
+        messages: Deque[tuple[int, int, resources_pb2.Message]] = deque(maxlen=selector.size)
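+        # the deque keeps only the last `size` messages seen; once the split
+        # message is found, `pending` counts down the (size - 1) // 2 following
+        # messages still to consume, so the match sits roughly in the middle
+        # when enough messages exist on both sides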
+        metadata = await field.get_metadata()
+        pending = -1
+        for page in range(1, metadata.pages + 1):
+            conversation_page = await field.db_get_value(page)
+            for index, message in enumerate(conversation_page.messages):
+                messages.append((page, index, message))
+                if pending > 0:
+                    pending -= 1
+                if message.ident == split:
+                    pending = (selector.size - 1) // 2
+                if pending == 0:
+                    break
+            if pending == 0:
+                break
+
+        for page, index, message in messages:
+            yield page, index, message
+
+    elif isinstance(selector, AnswerSelector):
+        if split is None:
+            return
+        found = await find_conversation_message(field, split)
+        if found is None:
+            return
+        page, index, message = found
+
+        found = await conversation_answer(field, start_from=(page, index))
+        if found is not None:
+            page, index, answer = found
+            yield page, index, answer
+
+    elif isinstance(selector, FullSelector):
+        async for page, index, message in iter_conversation_messages(field):
+            yield page, index, message
+
+    else:  # pragma: no cover
+        assert_never(selector)
+
+
+async def conversation_answer_or_after(
+    field: Conversation, field_id: FieldId
+) -> AsyncIterator[tuple[int, int, resources_pb2.Message]]:
+    m: resources_pb2.Message | None = None
+    # first, search for the message in the conversation
+    async for page, index, m in conversation_selector(field, field_id, MessageSelector()):
+        pass
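+    # with no id or index, MessageSelector falls back to the field id's split,
+    # so this loop yields at most one message; after it, m holds that message
+    # (or stays None when nothing matched)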
+
+    if m is None:
+        return
+
+    if m.type == resources_pb2.Message.MessageType.QUESTION:
+        # try to find an answer for this question
+        found = await conversation_answer(field, start_from=(page, index + 1))
+        if found is None:
+            return
+        else:
+            page, index, answer = found
+            yield page, index, answer
+
+    else:
+        # add a bunch of messages after this for more context
+        async for page, index, message in conversation_messages_after(
+            field, start_from=(page, index + 1), limit=CONVERSATION_MESSAGE_CONTEXT_EXPANSION
+        ):
+            yield page, index, message