nucliadb 6.7.2.post4874__py3-none-any.whl → 6.10.0.post5705__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (246)
  1. migrations/0023_backfill_pg_catalog.py +8 -4
  2. migrations/0028_extracted_vectors_reference.py +1 -1
  3. migrations/0029_backfill_field_status.py +3 -4
  4. migrations/0032_remove_old_relations.py +2 -3
  5. migrations/0038_backfill_catalog_field_labels.py +8 -4
  6. migrations/0039_backfill_converation_splits_metadata.py +106 -0
  7. migrations/0040_migrate_search_configurations.py +79 -0
  8. migrations/0041_reindex_conversations.py +137 -0
  9. migrations/pg/0010_shards_index.py +34 -0
  10. nucliadb/search/api/v1/resource/utils.py → migrations/pg/0011_catalog_statistics.py +5 -6
  11. migrations/pg/0012_catalog_statistics_undo.py +26 -0
  12. nucliadb/backups/create.py +2 -15
  13. nucliadb/backups/restore.py +4 -15
  14. nucliadb/backups/tasks.py +4 -1
  15. nucliadb/common/back_pressure/cache.py +2 -3
  16. nucliadb/common/back_pressure/materializer.py +7 -13
  17. nucliadb/common/back_pressure/settings.py +6 -6
  18. nucliadb/common/back_pressure/utils.py +1 -0
  19. nucliadb/common/cache.py +9 -9
  20. nucliadb/common/catalog/__init__.py +79 -0
  21. nucliadb/common/catalog/dummy.py +36 -0
  22. nucliadb/common/catalog/interface.py +85 -0
  23. nucliadb/{search/search/pgcatalog.py → common/catalog/pg.py} +330 -232
  24. nucliadb/common/catalog/utils.py +56 -0
  25. nucliadb/common/cluster/manager.py +8 -23
  26. nucliadb/common/cluster/rebalance.py +484 -112
  27. nucliadb/common/cluster/rollover.py +36 -9
  28. nucliadb/common/cluster/settings.py +4 -9
  29. nucliadb/common/cluster/utils.py +34 -8
  30. nucliadb/common/context/__init__.py +7 -8
  31. nucliadb/common/context/fastapi.py +1 -2
  32. nucliadb/common/datamanagers/__init__.py +2 -4
  33. nucliadb/common/datamanagers/atomic.py +9 -2
  34. nucliadb/common/datamanagers/cluster.py +1 -2
  35. nucliadb/common/datamanagers/fields.py +3 -4
  36. nucliadb/common/datamanagers/kb.py +6 -6
  37. nucliadb/common/datamanagers/labels.py +2 -3
  38. nucliadb/common/datamanagers/resources.py +10 -33
  39. nucliadb/common/datamanagers/rollover.py +5 -7
  40. nucliadb/common/datamanagers/search_configurations.py +1 -2
  41. nucliadb/common/datamanagers/synonyms.py +1 -2
  42. nucliadb/common/datamanagers/utils.py +4 -4
  43. nucliadb/common/datamanagers/vectorsets.py +4 -4
  44. nucliadb/common/external_index_providers/base.py +32 -5
  45. nucliadb/common/external_index_providers/manager.py +5 -34
  46. nucliadb/common/external_index_providers/settings.py +1 -27
  47. nucliadb/common/filter_expression.py +129 -41
  48. nucliadb/common/http_clients/exceptions.py +8 -0
  49. nucliadb/common/http_clients/processing.py +16 -23
  50. nucliadb/common/http_clients/utils.py +3 -0
  51. nucliadb/common/ids.py +82 -58
  52. nucliadb/common/locking.py +1 -2
  53. nucliadb/common/maindb/driver.py +9 -8
  54. nucliadb/common/maindb/local.py +5 -5
  55. nucliadb/common/maindb/pg.py +9 -8
  56. nucliadb/common/nidx.py +22 -5
  57. nucliadb/common/vector_index_config.py +1 -1
  58. nucliadb/export_import/datamanager.py +4 -3
  59. nucliadb/export_import/exporter.py +11 -19
  60. nucliadb/export_import/importer.py +13 -6
  61. nucliadb/export_import/tasks.py +2 -0
  62. nucliadb/export_import/utils.py +6 -18
  63. nucliadb/health.py +2 -2
  64. nucliadb/ingest/app.py +8 -8
  65. nucliadb/ingest/consumer/consumer.py +8 -10
  66. nucliadb/ingest/consumer/pull.py +10 -8
  67. nucliadb/ingest/consumer/service.py +5 -30
  68. nucliadb/ingest/consumer/shard_creator.py +16 -5
  69. nucliadb/ingest/consumer/utils.py +1 -1
  70. nucliadb/ingest/fields/base.py +37 -49
  71. nucliadb/ingest/fields/conversation.py +55 -9
  72. nucliadb/ingest/fields/exceptions.py +1 -2
  73. nucliadb/ingest/fields/file.py +22 -8
  74. nucliadb/ingest/fields/link.py +7 -7
  75. nucliadb/ingest/fields/text.py +2 -3
  76. nucliadb/ingest/orm/brain_v2.py +89 -57
  77. nucliadb/ingest/orm/broker_message.py +2 -4
  78. nucliadb/ingest/orm/entities.py +10 -209
  79. nucliadb/ingest/orm/index_message.py +128 -113
  80. nucliadb/ingest/orm/knowledgebox.py +91 -59
  81. nucliadb/ingest/orm/processor/auditing.py +1 -3
  82. nucliadb/ingest/orm/processor/data_augmentation.py +1 -2
  83. nucliadb/ingest/orm/processor/processor.py +98 -153
  84. nucliadb/ingest/orm/processor/sequence_manager.py +1 -2
  85. nucliadb/ingest/orm/resource.py +82 -71
  86. nucliadb/ingest/orm/utils.py +1 -1
  87. nucliadb/ingest/partitions.py +12 -1
  88. nucliadb/ingest/processing.py +17 -17
  89. nucliadb/ingest/serialize.py +202 -145
  90. nucliadb/ingest/service/writer.py +15 -114
  91. nucliadb/ingest/settings.py +36 -15
  92. nucliadb/ingest/utils.py +1 -2
  93. nucliadb/learning_proxy.py +23 -26
  94. nucliadb/metrics_exporter.py +20 -6
  95. nucliadb/middleware/__init__.py +82 -1
  96. nucliadb/migrator/datamanager.py +4 -11
  97. nucliadb/migrator/migrator.py +1 -2
  98. nucliadb/migrator/models.py +1 -2
  99. nucliadb/migrator/settings.py +1 -2
  100. nucliadb/models/internal/augment.py +614 -0
  101. nucliadb/models/internal/processing.py +19 -19
  102. nucliadb/openapi.py +2 -2
  103. nucliadb/purge/__init__.py +3 -8
  104. nucliadb/purge/orphan_shards.py +1 -2
  105. nucliadb/reader/__init__.py +5 -0
  106. nucliadb/reader/api/models.py +6 -13
  107. nucliadb/reader/api/v1/download.py +59 -38
  108. nucliadb/reader/api/v1/export_import.py +4 -4
  109. nucliadb/reader/api/v1/knowledgebox.py +37 -9
  110. nucliadb/reader/api/v1/learning_config.py +33 -14
  111. nucliadb/reader/api/v1/resource.py +61 -9
  112. nucliadb/reader/api/v1/services.py +18 -14
  113. nucliadb/reader/app.py +3 -1
  114. nucliadb/reader/reader/notifications.py +1 -2
  115. nucliadb/search/api/v1/__init__.py +3 -0
  116. nucliadb/search/api/v1/ask.py +3 -4
  117. nucliadb/search/api/v1/augment.py +585 -0
  118. nucliadb/search/api/v1/catalog.py +15 -19
  119. nucliadb/search/api/v1/find.py +16 -22
  120. nucliadb/search/api/v1/hydrate.py +328 -0
  121. nucliadb/search/api/v1/knowledgebox.py +1 -2
  122. nucliadb/search/api/v1/predict_proxy.py +1 -2
  123. nucliadb/search/api/v1/resource/ask.py +28 -8
  124. nucliadb/search/api/v1/resource/ingestion_agents.py +5 -6
  125. nucliadb/search/api/v1/resource/search.py +9 -11
  126. nucliadb/search/api/v1/retrieve.py +130 -0
  127. nucliadb/search/api/v1/search.py +28 -32
  128. nucliadb/search/api/v1/suggest.py +11 -14
  129. nucliadb/search/api/v1/summarize.py +1 -2
  130. nucliadb/search/api/v1/utils.py +2 -2
  131. nucliadb/search/app.py +3 -2
  132. nucliadb/search/augmentor/__init__.py +21 -0
  133. nucliadb/search/augmentor/augmentor.py +232 -0
  134. nucliadb/search/augmentor/fields.py +704 -0
  135. nucliadb/search/augmentor/metrics.py +24 -0
  136. nucliadb/search/augmentor/paragraphs.py +334 -0
  137. nucliadb/search/augmentor/resources.py +238 -0
  138. nucliadb/search/augmentor/utils.py +33 -0
  139. nucliadb/search/lifecycle.py +3 -1
  140. nucliadb/search/predict.py +33 -19
  141. nucliadb/search/predict_models.py +8 -9
  142. nucliadb/search/requesters/utils.py +11 -10
  143. nucliadb/search/search/cache.py +19 -42
  144. nucliadb/search/search/chat/ask.py +131 -59
  145. nucliadb/search/search/chat/exceptions.py +3 -5
  146. nucliadb/search/search/chat/fetcher.py +201 -0
  147. nucliadb/search/search/chat/images.py +6 -4
  148. nucliadb/search/search/chat/old_prompt.py +1375 -0
  149. nucliadb/search/search/chat/parser.py +510 -0
  150. nucliadb/search/search/chat/prompt.py +563 -615
  151. nucliadb/search/search/chat/query.py +453 -32
  152. nucliadb/search/search/chat/rpc.py +85 -0
  153. nucliadb/search/search/fetch.py +3 -4
  154. nucliadb/search/search/filters.py +8 -11
  155. nucliadb/search/search/find.py +33 -31
  156. nucliadb/search/search/find_merge.py +124 -331
  157. nucliadb/search/search/graph_strategy.py +14 -12
  158. nucliadb/search/search/hydrator/__init__.py +49 -0
  159. nucliadb/search/search/hydrator/fields.py +217 -0
  160. nucliadb/search/search/hydrator/images.py +130 -0
  161. nucliadb/search/search/hydrator/paragraphs.py +323 -0
  162. nucliadb/search/search/hydrator/resources.py +60 -0
  163. nucliadb/search/search/ingestion_agents.py +5 -5
  164. nucliadb/search/search/merge.py +90 -94
  165. nucliadb/search/search/metrics.py +24 -7
  166. nucliadb/search/search/paragraphs.py +7 -9
  167. nucliadb/search/search/predict_proxy.py +44 -18
  168. nucliadb/search/search/query.py +14 -86
  169. nucliadb/search/search/query_parser/fetcher.py +51 -82
  170. nucliadb/search/search/query_parser/models.py +19 -48
  171. nucliadb/search/search/query_parser/old_filters.py +20 -19
  172. nucliadb/search/search/query_parser/parsers/ask.py +5 -6
  173. nucliadb/search/search/query_parser/parsers/catalog.py +7 -11
  174. nucliadb/search/search/query_parser/parsers/common.py +21 -13
  175. nucliadb/search/search/query_parser/parsers/find.py +6 -29
  176. nucliadb/search/search/query_parser/parsers/graph.py +18 -28
  177. nucliadb/search/search/query_parser/parsers/retrieve.py +207 -0
  178. nucliadb/search/search/query_parser/parsers/search.py +15 -56
  179. nucliadb/search/search/query_parser/parsers/unit_retrieval.py +8 -29
  180. nucliadb/search/search/rank_fusion.py +18 -13
  181. nucliadb/search/search/rerankers.py +6 -7
  182. nucliadb/search/search/retrieval.py +300 -0
  183. nucliadb/search/search/summarize.py +5 -6
  184. nucliadb/search/search/utils.py +3 -4
  185. nucliadb/search/settings.py +1 -2
  186. nucliadb/standalone/api_router.py +1 -1
  187. nucliadb/standalone/app.py +4 -3
  188. nucliadb/standalone/auth.py +5 -6
  189. nucliadb/standalone/lifecycle.py +2 -2
  190. nucliadb/standalone/run.py +5 -4
  191. nucliadb/standalone/settings.py +5 -6
  192. nucliadb/standalone/versions.py +3 -4
  193. nucliadb/tasks/consumer.py +13 -8
  194. nucliadb/tasks/models.py +2 -1
  195. nucliadb/tasks/producer.py +3 -3
  196. nucliadb/tasks/retries.py +8 -7
  197. nucliadb/train/api/utils.py +1 -3
  198. nucliadb/train/api/v1/shards.py +1 -2
  199. nucliadb/train/api/v1/trainset.py +1 -2
  200. nucliadb/train/app.py +1 -1
  201. nucliadb/train/generator.py +4 -4
  202. nucliadb/train/generators/field_classifier.py +2 -2
  203. nucliadb/train/generators/field_streaming.py +6 -6
  204. nucliadb/train/generators/image_classifier.py +2 -2
  205. nucliadb/train/generators/paragraph_classifier.py +2 -2
  206. nucliadb/train/generators/paragraph_streaming.py +2 -2
  207. nucliadb/train/generators/question_answer_streaming.py +2 -2
  208. nucliadb/train/generators/sentence_classifier.py +4 -10
  209. nucliadb/train/generators/token_classifier.py +3 -2
  210. nucliadb/train/generators/utils.py +6 -5
  211. nucliadb/train/nodes.py +3 -3
  212. nucliadb/train/resource.py +6 -8
  213. nucliadb/train/settings.py +3 -4
  214. nucliadb/train/types.py +11 -11
  215. nucliadb/train/upload.py +3 -2
  216. nucliadb/train/uploader.py +1 -2
  217. nucliadb/train/utils.py +1 -2
  218. nucliadb/writer/api/v1/export_import.py +4 -1
  219. nucliadb/writer/api/v1/field.py +15 -14
  220. nucliadb/writer/api/v1/knowledgebox.py +18 -56
  221. nucliadb/writer/api/v1/learning_config.py +5 -4
  222. nucliadb/writer/api/v1/resource.py +9 -20
  223. nucliadb/writer/api/v1/services.py +10 -132
  224. nucliadb/writer/api/v1/upload.py +73 -72
  225. nucliadb/writer/app.py +8 -2
  226. nucliadb/writer/resource/basic.py +12 -15
  227. nucliadb/writer/resource/field.py +43 -5
  228. nucliadb/writer/resource/origin.py +7 -0
  229. nucliadb/writer/settings.py +2 -3
  230. nucliadb/writer/tus/__init__.py +2 -3
  231. nucliadb/writer/tus/azure.py +5 -7
  232. nucliadb/writer/tus/dm.py +3 -3
  233. nucliadb/writer/tus/exceptions.py +3 -4
  234. nucliadb/writer/tus/gcs.py +15 -22
  235. nucliadb/writer/tus/s3.py +2 -3
  236. nucliadb/writer/tus/storage.py +3 -3
  237. {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/METADATA +10 -11
  238. nucliadb-6.10.0.post5705.dist-info/RECORD +410 -0
  239. nucliadb/common/datamanagers/entities.py +0 -139
  240. nucliadb/common/external_index_providers/pinecone.py +0 -894
  241. nucliadb/ingest/orm/processor/pgcatalog.py +0 -129
  242. nucliadb/search/search/hydrator.py +0 -197
  243. nucliadb-6.7.2.post4874.dist-info/RECORD +0 -383
  244. {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/WHEEL +0 -0
  245. {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/entry_points.txt +0 -0
  246. {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1375 @@
+ # Copyright (C) 2021 Bosutech XXI S.L.
+ #
+ # nucliadb is offered under the AGPL v3.0 and as commercial software.
+ # For commercial licensing, contact us at info@nuclia.com.
+ #
+ # AGPL:
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU Affero General Public License as
+ # published by the Free Software Foundation, either version 3 of the
+ # License, or (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU Affero General Public License for more details.
+ #
+ # You should have received a copy of the GNU Affero General Public License
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
+ #
+ import asyncio
+ import copy
+ from collections import deque
+ from collections.abc import Sequence
+ from dataclasses import dataclass
+ from typing import Deque, cast
+
+ import yaml
+ from pydantic import BaseModel
+
+ from nucliadb.common import datamanagers
+ from nucliadb.common.ids import FIELD_TYPE_PB_TO_STR, FIELD_TYPE_STR_TO_PB, FieldId, ParagraphId
+ from nucliadb.common.maindb.utils import get_driver
+ from nucliadb.common.models_utils import from_proto
+ from nucliadb.ingest.fields.base import Field
+ from nucliadb.ingest.fields.conversation import Conversation
+ from nucliadb.ingest.fields.file import File
+ from nucliadb.ingest.orm.knowledgebox import KnowledgeBox as KnowledgeBoxORM
+ from nucliadb.search import logger
+ from nucliadb.search.search import cache
+ from nucliadb.search.search.chat.images import (
+     get_file_thumbnail_image,
+     get_page_image,
+     get_paragraph_image,
+ )
+ from nucliadb.search.search.metrics import Metrics
+ from nucliadb.search.search.paragraphs import get_paragraph_text
+ from nucliadb_models.labels import translate_alias_to_system_label
+ from nucliadb_models.metadata import Extra, Origin
+ from nucliadb_models.search import (
+     SCORE_TYPE,
+     AugmentedContext,
+     AugmentedTextBlock,
+     ConversationalStrategy,
+     FieldExtensionStrategy,
+     FindParagraph,
+     FullResourceStrategy,
+     HierarchyResourceStrategy,
+     Image,
+     ImageRagStrategy,
+     ImageRagStrategyName,
+     MetadataExtensionStrategy,
+     MetadataExtensionType,
+     NeighbouringParagraphsStrategy,
+     PageImageStrategy,
+     ParagraphImageStrategy,
+     PromptContext,
+     PromptContextImages,
+     PromptContextOrder,
+     RagStrategy,
+     RagStrategyName,
+     TableImageStrategy,
+     TextBlockAugmentationType,
+     TextPosition,
+ )
+ from nucliadb_protos import resources_pb2
+ from nucliadb_protos.resources_pb2 import ExtractedText, FieldComputedMetadata
+ from nucliadb_telemetry.metrics import Observer
+ from nucliadb_utils.asyncio_utils import ConcurrentRunner, run_concurrently
+ from nucliadb_utils.utilities import get_storage
+
+ MAX_RESOURCE_TASKS = 5
+ MAX_RESOURCE_FIELD_TASKS = 4
+
+
+ # Number of messages to pull after a match in a message
+ # The hope here is it will be enough to get the answer to the question.
+ CONVERSATION_MESSAGE_CONTEXT_EXPANSION = 15
+
+ TextBlockId = ParagraphId | FieldId
+
+
+ class ParagraphIdNotFoundInExtractedMetadata(Exception):
+     pass
+
+
+ class CappedPromptContext:
+     """
+     Class to keep track of the size (in number of characters) of the prompt context
+     and automatically trim data that exceeds the limit when it's being set on the dictionary.
+     """
+
+     def __init__(self, max_size: int | None):
+         self.output: PromptContext = {}
+         self.images: PromptContextImages = {}
+         self.max_size = max_size
+
+     def __setitem__(self, key: str, value: str) -> None:
+         self.output.__setitem__(key, value)
+
+     def __getitem__(self, key: str) -> str:
+         return self.output.__getitem__(key)
+
+     def __contains__(self, key: str) -> bool:
+         return key in self.output
+
+     def __delitem__(self, key: str) -> None:
+         try:
+             self.output.__delitem__(key)
+         except KeyError:
+             pass
+
+     def text_block_ids(self) -> list[str]:
+         return list(self.output.keys())
+
+     @property
+     def size(self) -> int:
+         """
+         Returns the total size of the context in characters.
+         """
+         return sum(len(text) for text in self.output.values())
+
+     def cap(self) -> dict[str, str]:
+         """
+         This method will trim the context to the maximum size if it exceeds it.
+         It will remove text from the most recent entries first, until the size is below the limit.
+         """
+         if self.max_size is None:
+             return self.output
+
+         if self.size <= self.max_size:
+             return self.output
+
+         logger.info("Removing text from context to fit within the max size limit")
+         # Iterate the dictionary in reverse order of insertion
+         for key in reversed(list(self.output.keys())):
+             current_size = self.size
+             if current_size <= self.max_size:
+                 break
+             # Remove text from the value
+             text = self.output[key]
+             # If removing the whole text still keeps the total size above the limit, remove it
+             if current_size - len(text) >= self.max_size:
+                 del self.output[key]
+             else:
+                 # Otherwise, trim the text to fit within the limit
+                 excess_size = current_size - self.max_size
+                 if excess_size > 0:
+                     trimmed_text = text[:-excess_size]
+                     self.output[key] = trimmed_text
+         return self.output
+
+
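As a quick reference for reviewers, a minimal sketch of the capping behavior (the module path is taken from entry 148 in the file list above; the values are made up):

    from nucliadb.search.search.chat.old_prompt import CappedPromptContext

    ctx = CappedPromptContext(max_size=10)
    ctx["block-1"] = "abcdef"  # 6 chars
    ctx["block-2"] = "ghijkl"  # 6 chars -> 12 total, 2 over the limit
    capped = ctx.cap()
    assert capped["block-1"] == "abcdef"  # earlier (more relevant) entry kept whole
    assert capped["block-2"] == "ghij"    # last-inserted entry trimmed by the 2-char excess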
+ async def get_next_conversation_messages(
+     *,
+     field_obj: Conversation,
+     page: int,
+     start_idx: int,
+     num_messages: int,
+     message_type: resources_pb2.Message.MessageType.ValueType | None = None,
+     msg_to: str | None = None,
+ ) -> list[resources_pb2.Message]:
+     output = []
+     cmetadata = await field_obj.get_metadata()
+     for current_page in range(page, cmetadata.pages + 1):
+         conv = await field_obj.db_get_value(current_page)
+         for message in conv.messages[start_idx:]:
+             if message_type is not None and message.type != message_type:  # pragma: no cover
+                 continue
+             if msg_to is not None and msg_to not in message.to:  # pragma: no cover
+                 continue
+             output.append(message)
+             if len(output) >= num_messages:
+                 return output
+         start_idx = 0
+
+     return output
+
+
+ async def find_conversation_message(
+     field_obj: Conversation, mident: str
+ ) -> tuple[resources_pb2.Message | None, int, int]:
+     cmetadata = await field_obj.get_metadata()
+     for page in range(1, cmetadata.pages + 1):
+         conv = await field_obj.db_get_value(page)
+         for idx, message in enumerate(conv.messages):
+             if message.ident == mident:
+                 return message, page, idx
+     return None, -1, -1
+
+
+ async def get_expanded_conversation_messages(
+     *,
+     kb: KnowledgeBoxORM,
+     rid: str,
+     field_id: str,
+     mident: str,
+     max_messages: int = CONVERSATION_MESSAGE_CONTEXT_EXPANSION,
+ ) -> list[resources_pb2.Message]:
+     resource = await kb.get(rid)
+     if resource is None:  # pragma: no cover
+         return []
+     field_obj: Conversation = await resource.get_field(field_id, FIELD_TYPE_STR_TO_PB["c"], load=True)  # type: ignore
+     found_message, found_page, found_idx = await find_conversation_message(
+         field_obj=field_obj, mident=mident
+     )
+     if found_message is None:  # pragma: no cover
+         return []
+     elif found_message.type == resources_pb2.Message.MessageType.QUESTION:
+         # only try to get answer if it was a question
+         return await get_next_conversation_messages(
+             field_obj=field_obj,
+             page=found_page,
+             start_idx=found_idx + 1,
+             num_messages=1,
+             message_type=resources_pb2.Message.MessageType.ANSWER,
+         )
+     else:
+         return await get_next_conversation_messages(
+             field_obj=field_obj,
+             page=found_page,
+             start_idx=found_idx + 1,
+             num_messages=max_messages,
+         )
+
+
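To illustrate the paging walk above, here is a simplified, hypothetical analogue that uses plain lists in place of Conversation pages (pages are 1-indexed; start_idx only applies to the first page visited):

    def next_messages(pages: list[list[str]], page: int, start_idx: int, num: int) -> list[str]:
        output: list[str] = []
        for current in range(page, len(pages) + 1):
            for message in pages[current - 1][start_idx:]:
                output.append(message)
                if len(output) >= num:
                    return output
            start_idx = 0  # subsequent pages are read from the beginning
        return output

    pages = [["q1", "a1"], ["q2", "a2", "q3"]]
    assert next_messages(pages, page=1, start_idx=1, num=3) == ["a1", "q2", "a2"]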
+ async def default_prompt_context(
+     context: CappedPromptContext,
+     kbid: str,
+     ordered_paragraphs: list[FindParagraph],
+ ) -> None:
+     """
+     - Updates context (which is an ordered dict of text_block_id -> context_text).
+     - text_block_id is typically the paragraph id, but has a special value for the
+       user context. (USER_CONTEXT_0, USER_CONTEXT_1, ...)
+     - Paragraphs are inserted in order of relevance, by increasing `order` field
+       of the find result paragraphs.
+     - User context is inserted first, in order of appearance.
+     - Using a dict prevents duplicates pulled in through conversation expansion.
+     """
+     # Retrieved paragraphs arrive already sorted by relevance (most relevant first)
+     async with get_driver().ro_transaction() as txn:
+         storage = await get_storage()
+         kb = KnowledgeBoxORM(txn, storage, kbid)
+         for paragraph in ordered_paragraphs:
+             context[paragraph.id] = _clean_paragraph_text(paragraph)
+
+             # If the paragraph is a conversation and it matches semantically, we assume we
+             # have matched with the question, therefore try to include the answer to the
+             # context by pulling the next few messages of the conversation field
+             rid, field_type, field_id, mident = paragraph.id.split("/")[:4]
+             if field_type == "c" and paragraph.score_type in (
+                 SCORE_TYPE.VECTOR,
+                 SCORE_TYPE.BOTH,
+             ):
+                 expanded_msgs = await get_expanded_conversation_messages(
+                     kb=kb, rid=rid, field_id=field_id, mident=mident
+                 )
+                 for msg in expanded_msgs:
+                     text = msg.content.text.strip()
+                     pid = f"{rid}/{field_type}/{field_id}/{msg.ident}/0-{len(msg.content.text)}"
+                     context[pid] = text
+
+
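The conversation text block ids built above follow a rid/field_type/field_id/message_ident/start-end shape, which is what the split("/") call relies on (illustrative values):

    pid = "a1b2c3/c/chat/msg-42/0-120"
    rid, field_type, field_id, mident = pid.split("/")[:4]
    assert (rid, field_type, field_id, mident) == ("a1b2c3", "c", "chat", "msg-42")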
+ async def full_resource_prompt_context(
+     context: CappedPromptContext,
+     kbid: str,
+     ordered_paragraphs: list[FindParagraph],
+     resource: str | None,
+     strategy: FullResourceStrategy,
+     metrics: Metrics,
+     augmented_context: AugmentedContext,
+ ) -> None:
+     """
+     Algorithm steps:
+     - Collect the list of resources in the results (in order of relevance).
+     - For each resource, collect the extracted text from all its fields and craft the context.
+     Arguments:
+         context: The context to be updated.
+         kbid: The knowledge box id.
+         ordered_paragraphs: The results of the retrieval (find) operation.
+         resource: The resource to be included in the context. This is used only when chatting with a specific resource with no retrieval.
+         strategy: strategy instance containing, for example, the number of full resources to include in the context.
+     """
+     if resource is not None:
+         # The user has specified a resource to be included in the context.
+         ordered_resources = [resource]
+     else:
+         # Collect the list of resources in the results (in order of relevance).
+         ordered_resources = []
+         for paragraph in ordered_paragraphs:
+             resource_uuid = parse_text_block_id(paragraph.id).rid
+             if resource_uuid not in ordered_resources:
+                 skip = False
+                 if strategy.apply_to is not None:
+                     # decide whether the resource should be extended or not
+                     for label in strategy.apply_to.exclude:
+                         skip = skip or (
+                             translate_alias_to_system_label(label) in (paragraph.labels or [])
+                         )
+
+                 if not skip:
+                     ordered_resources.append(resource_uuid)
+
+     # For each resource, collect the extracted text from all its fields.
+     resources_extracted_texts = await run_concurrently(
+         [
+             hydrate_resource_text(kbid, resource_uuid, max_concurrent_tasks=MAX_RESOURCE_FIELD_TASKS)
+             for resource_uuid in ordered_resources[: strategy.count]
+         ],
+         max_concurrent=MAX_RESOURCE_TASKS,
+     )
+     added_fields = set()
+     for resource_extracted_texts in resources_extracted_texts:
+         if resource_extracted_texts is None:
+             continue
+         for field, extracted_text in resource_extracted_texts:
+             # First off, remove the text block ids from paragraphs that belong to
+             # the same field, as otherwise the context will be duplicated.
+             for tb_id in context.text_block_ids():
+                 if tb_id.startswith(field.full()):
+                     del context[tb_id]
+             # Add the extracted text of each field to the context.
+             context[field.full()] = extracted_text
+             augmented_context.fields[field.full()] = AugmentedTextBlock(
+                 id=field.full(),
+                 text=extracted_text,
+                 augmentation_type=TextBlockAugmentationType.FULL_RESOURCE,
+             )
+
+             added_fields.add(field.full())
+
+     metrics.set("full_resource_ops", len(added_fields))
+
+     if strategy.include_remaining_text_blocks:
+         for paragraph in ordered_paragraphs:
+             pid = cast(ParagraphId, parse_text_block_id(paragraph.id))
+             if pid.field_id.full() not in added_fields:
+                 context[paragraph.id] = _clean_paragraph_text(paragraph)
+
+
+ async def extend_prompt_context_with_metadata(
+     context: CappedPromptContext,
+     kbid: str,
+     strategy: MetadataExtensionStrategy,
+     metrics: Metrics,
+     augmented_context: AugmentedContext,
+ ) -> None:
+     text_block_ids: list[TextBlockId] = []
+     for text_block_id in context.text_block_ids():
+         try:
+             text_block_ids.append(parse_text_block_id(text_block_id))
+         except ValueError:  # pragma: no cover
+             # Some text block ids are not paragraphs nor fields, so they are skipped
+             # (e.g. USER_CONTEXT_0, when the user provides extra context)
+             continue
+     if len(text_block_ids) == 0:  # pragma: no cover
+         return
+
+     ops = 0
+     if MetadataExtensionType.ORIGIN in strategy.types:
+         ops += 1
+         await extend_prompt_context_with_origin_metadata(
+             context, kbid, text_block_ids, augmented_context
+         )
+
+     if MetadataExtensionType.CLASSIFICATION_LABELS in strategy.types:
+         ops += 1
+         await extend_prompt_context_with_classification_labels(
+             context, kbid, text_block_ids, augmented_context
+         )
+
+     if MetadataExtensionType.NERS in strategy.types:
+         ops += 1
+         await extend_prompt_context_with_ner(context, kbid, text_block_ids, augmented_context)
+
+     if MetadataExtensionType.EXTRA_METADATA in strategy.types:
+         ops += 1
+         await extend_prompt_context_with_extra_metadata(context, kbid, text_block_ids, augmented_context)
+
+     metrics.set("metadata_extension_ops", ops * len(text_block_ids))
+
+
+ def parse_text_block_id(text_block_id: str) -> TextBlockId:
+     try:
+         # Typically, the text block id is a paragraph id
+         return ParagraphId.from_string(text_block_id)
+     except ValueError:
+         # When we're doing `full_resource` or `hierarchy` strategies, the text block id
+         # is a field id
+         return FieldId.from_string(text_block_id)
+
+
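A small sketch of the fallback, assuming parse_text_block_id is in scope (the ids are invented; paragraph id strings carry a trailing start-end range, field id strings do not):

    from nucliadb.common.ids import FieldId, ParagraphId

    assert isinstance(parse_text_block_id("a1b2c3/f/myfile/10-42"), ParagraphId)
    assert isinstance(parse_text_block_id("a1b2c3/f/myfile"), FieldId)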
+ async def extend_prompt_context_with_origin_metadata(
+     context: CappedPromptContext,
+     kbid,
+     text_block_ids: list[TextBlockId],
+     augmented_context: AugmentedContext,
+ ):
+     async def _get_origin(kbid: str, rid: str) -> tuple[str, Origin | None]:
+         origin = None
+         resource = await cache.get_resource(kbid, rid)
+         if resource is not None:
+             pb_origin = await resource.get_origin()
+             if pb_origin is not None:
+                 origin = from_proto.origin(pb_origin)
+         return rid, origin
+
+     rids = {tb_id.rid for tb_id in text_block_ids}
+     origins = await run_concurrently([_get_origin(kbid, rid) for rid in rids])
+     rid_to_origin = {rid: origin for rid, origin in origins if origin is not None}
+     for tb_id in text_block_ids:
+         origin = rid_to_origin.get(tb_id.rid)
+         if origin is not None and tb_id.full() in context:
+             text = context.output.pop(tb_id.full())
+             extended_text = text + f"\n\nDOCUMENT METADATA AT ORIGIN:\n{to_yaml(origin)}"
+             context[tb_id.full()] = extended_text
+             augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
+                 id=tb_id.full(),
+                 text=extended_text,
+                 parent=tb_id.full(),
+                 augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
+             )
+
+
+ async def extend_prompt_context_with_classification_labels(
+     context: CappedPromptContext,
+     kbid: str,
+     text_block_ids: list[TextBlockId],
+     augmented_context: AugmentedContext,
+ ):
+     async def _get_labels(kbid: str, _id: TextBlockId) -> tuple[TextBlockId, list[tuple[str, str]]]:
+         fid = _id if isinstance(_id, FieldId) else _id.field_id
+         labels = set()
+         resource = await cache.get_resource(kbid, fid.rid)
+         if resource is not None:
+             pb_basic = await resource.get_basic()
+             if pb_basic is not None:
+                 # Add the classification labels of the resource
+                 for classif in pb_basic.usermetadata.classifications:
+                     labels.add((classif.labelset, classif.label))
+                 # Add the classification labels of the field
+                 for fc in pb_basic.computedmetadata.field_classifications:
+                     if fc.field.field == fid.key and fc.field.field_type == fid.pb_type:
+                         for classif in fc.classifications:
+                             if classif.cancelled_by_user:  # pragma: no cover
+                                 continue
+                             labels.add((classif.labelset, classif.label))
+         return _id, list(labels)
+
+     classif_labels = await run_concurrently([_get_labels(kbid, tb_id) for tb_id in text_block_ids])
+     tb_id_to_labels = {tb_id: labels for tb_id, labels in classif_labels if len(labels) > 0}
+     for tb_id in text_block_ids:
+         labels = tb_id_to_labels.get(tb_id)
+         if labels is not None and tb_id.full() in context:
+             text = context.output.pop(tb_id.full())
+
+             labels_text = "DOCUMENT CLASSIFICATION LABELS:"
+             for labelset, label in labels:
+                 labels_text += f"\n - {label} ({labelset})"
+             extended_text = text + "\n\n" + labels_text
+
+             context[tb_id.full()] = extended_text
+             augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
+                 id=tb_id.full(),
+                 text=extended_text,
+                 parent=tb_id.full(),
+                 augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
+             )
+
+
+ async def extend_prompt_context_with_ner(
+     context: CappedPromptContext,
+     kbid: str,
+     text_block_ids: list[TextBlockId],
+     augmented_context: AugmentedContext,
+ ):
+     async def _get_ners(kbid: str, _id: TextBlockId) -> tuple[TextBlockId, dict[str, set[str]]]:
+         fid = _id if isinstance(_id, FieldId) else _id.field_id
+         ners: dict[str, set[str]] = {}
+         resource = await cache.get_resource(kbid, fid.rid)
+         if resource is not None:
+             field = await resource.get_field(fid.key, fid.pb_type, load=False)
+             fcm = await field.get_field_metadata()
+             if fcm is not None:
+                 # Data Augmentation + Processor entities
+                 for (
+                     data_augmentation_task_id,
+                     entities_wrapper,
+                 ) in fcm.metadata.entities.items():
+                     for entity in entities_wrapper.entities:
+                         ners.setdefault(entity.label, set()).add(entity.text)
+                 # Legacy processor entities
+                 # TODO: Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
+                 for token, family in fcm.metadata.ner.items():
+                     ners.setdefault(family, set()).add(token)
+         return _id, ners
+
+     nerss = await run_concurrently([_get_ners(kbid, tb_id) for tb_id in text_block_ids])
+     tb_id_to_ners = {tb_id: ners for tb_id, ners in nerss if len(ners) > 0}
+     for tb_id in text_block_ids:
+         ners = tb_id_to_ners.get(tb_id)
+         if ners is not None and tb_id.full() in context:
+             text = context.output.pop(tb_id.full())
+
+             ners_text = "DOCUMENT NAMED ENTITIES (NERs):"
+             for family, tokens in ners.items():
+                 ners_text += f"\n - {family}:"
+                 for token in sorted(list(tokens)):
+                     ners_text += f"\n - {token}"
+
+             extended_text = text + "\n\n" + ners_text
+
+             context[tb_id.full()] = extended_text
+             augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
+                 id=tb_id.full(),
+                 text=extended_text,
+                 parent=tb_id.full(),
+                 augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
+             )
+
+
+ async def extend_prompt_context_with_extra_metadata(
+     context: CappedPromptContext,
+     kbid: str,
+     text_block_ids: list[TextBlockId],
+     augmented_context: AugmentedContext,
+ ):
+     async def _get_extra(kbid: str, rid: str) -> tuple[str, Extra | None]:
+         extra = None
+         resource = await cache.get_resource(kbid, rid)
+         if resource is not None:
+             pb_extra = await resource.get_extra()
+             if pb_extra is not None:
+                 extra = from_proto.extra(pb_extra)
+         return rid, extra
+
+     rids = {tb_id.rid for tb_id in text_block_ids}
+     extras = await run_concurrently([_get_extra(kbid, rid) for rid in rids])
+     rid_to_extra = {rid: extra for rid, extra in extras if extra is not None}
+     for tb_id in text_block_ids:
+         extra = rid_to_extra.get(tb_id.rid)
+         if extra is not None and tb_id.full() in context:
+             text = context.output.pop(tb_id.full())
+             extended_text = text + f"\n\nDOCUMENT EXTRA METADATA:\n{to_yaml(extra)}"
+             context[tb_id.full()] = extended_text
+             augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
+                 id=tb_id.full(),
+                 text=extended_text,
+                 parent=tb_id.full(),
+                 augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
+             )
+
+
+ def to_yaml(obj: BaseModel) -> str:
+     return yaml.dump(
+         obj.model_dump(exclude_none=True, exclude_defaults=True, exclude_unset=True),
+         default_flow_style=False,
+         indent=2,
+         sort_keys=True,
+     )
+
+
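For instance, on a toy pydantic model (Demo is hypothetical), only explicitly set, non-default, non-None fields survive into the YAML:

    from pydantic import BaseModel

    class Demo(BaseModel):
        url: str | None = None
        tags: list[str] = []

    print(to_yaml(Demo(tags=["news", "sports"])))
    # tags:
    # - news
    # - sports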
+ async def field_extension_prompt_context(
+     context: CappedPromptContext,
+     kbid: str,
+     ordered_paragraphs: list[FindParagraph],
+     strategy: FieldExtensionStrategy,
+     metrics: Metrics,
+     augmented_context: AugmentedContext,
+ ) -> None:
+     """
+     Algorithm steps:
+     - Collect the list of resources in the results (in order of relevance).
+     - For each resource, collect the extracted text from all its fields.
+     - Add the extracted text of each field to the beginning of the context.
+     - Add the extracted text of each paragraph to the end of the context.
+     """
+     ordered_resources = []
+     for paragraph in ordered_paragraphs:
+         resource_uuid = ParagraphId.from_string(paragraph.id).rid
+         if resource_uuid not in ordered_resources:
+             ordered_resources.append(resource_uuid)
+
+     extend_field_ids = await get_matching_field_ids(kbid, ordered_resources, strategy)
+     tasks = [hydrate_field_text(kbid, fid) for fid in extend_field_ids]
+     field_extracted_texts = await run_concurrently(tasks)
+
+     metrics.set("field_extension_ops", len(field_extracted_texts))
+
+     for result in field_extracted_texts:
+         if result is None:  # pragma: no cover
+             continue
+         field, extracted_text = result
+         # First off, remove the text block ids from paragraphs that belong to
+         # the same field, as otherwise the context will be duplicated.
+         for tb_id in context.text_block_ids():
+             if tb_id.startswith(field.full()):
+                 del context[tb_id]
+         # Add the extracted text of each field to the beginning of the context.
+         if field.full() not in context:
+             context[field.full()] = extracted_text
+             augmented_context.fields[field.full()] = AugmentedTextBlock(
+                 id=field.full(),
+                 text=extracted_text,
+                 augmentation_type=TextBlockAugmentationType.FIELD_EXTENSION,
+             )
+
+     # Add the extracted text of each paragraph to the end of the context.
+     for paragraph in ordered_paragraphs:
+         if paragraph.id not in context:
+             context[paragraph.id] = _clean_paragraph_text(paragraph)
+
+
+ async def get_matching_field_ids(
+     kbid: str, ordered_resources: list[str], strategy: FieldExtensionStrategy
+ ) -> list[FieldId]:
+     extend_field_ids: list[FieldId] = []
+     # Fetch the extracted texts of the specified fields for each resource
+     for resource_uuid in ordered_resources:
+         for field_id in strategy.fields:
+             try:
+                 fid = FieldId.from_string(f"{resource_uuid}/{field_id.strip('/')}")
+                 extend_field_ids.append(fid)
+             except ValueError:  # pragma: no cover
+                 # Invalid field id, skipping
+                 continue
+     if len(strategy.data_augmentation_field_prefixes) > 0:
+         for resource_uuid in ordered_resources:
+             all_field_ids = await datamanagers.atomic.resources.get_all_field_ids(
+                 kbid=kbid, rid=resource_uuid, for_update=False
+             )
+             if all_field_ids is None:
+                 continue
+             for fieldid in all_field_ids.fields:
+                 # Generated fields are always text fields starting with "da-"
+                 if any(
+                     (
+                         fieldid.field_type == resources_pb2.FieldType.TEXT
+                         and fieldid.field.startswith(f"da-{prefix}-")
+                     )
+                     for prefix in strategy.data_augmentation_field_prefixes
+                 ):
+                     extend_field_ids.append(
+                         FieldId.from_pb(
+                             rid=resource_uuid, field_type=fieldid.field_type, key=fieldid.field
+                         )
+                     )
+     return extend_field_ids
+
+
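The data-augmentation naming rule above reduces to a simple prefix check on text fields (the field names here are invented):

    prefixes = ["summary", "labels"]
    assert any("da-summary-abc123".startswith(f"da-{p}-") for p in prefixes)
    assert not any("title".startswith(f"da-{p}-") for p in prefixes)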
+ async def get_orm_field(kbid: str, field_id: FieldId) -> Field | None:
+     resource = await cache.get_resource(kbid, field_id.rid)
+     if resource is None:  # pragma: no cover
+         return None
+     return await resource.get_field(key=field_id.key, type=field_id.pb_type, load=False)
+
+
+ async def neighbouring_paragraphs_prompt_context(
+     context: CappedPromptContext,
+     kbid: str,
+     ordered_text_blocks: list[FindParagraph],
+     strategy: NeighbouringParagraphsStrategy,
+     metrics: Metrics,
+     augmented_context: AugmentedContext,
+ ) -> None:
+     """
+     This function will get the paragraph texts and then craft a context with the neighbouring
+     paragraphs of the paragraphs in the ordered_text_blocks list.
+     """
+     retrieved_paragraphs_ids = [
+         ParagraphId.from_string(text_block.id) for text_block in ordered_text_blocks
+     ]
+     unique_field_ids = list({pid.field_id for pid in retrieved_paragraphs_ids})
+
+     # Get extracted texts and metadatas for all fields
+     hydrated_field_ids: list[FieldId] = []
+     fm_ops = []
+     et_ops = []
+     for field_id in unique_field_ids:
+         field = await get_orm_field(kbid, field_id)
+         if field is None:
+             continue
+         hydrated_field_ids.append(field_id)
+         fm_ops.append(asyncio.create_task(field.get_field_metadata()))
+         et_ops.append(asyncio.create_task(field.get_extracted_text()))
+
+     # Zip against the ids that actually spawned tasks, so pairs stay aligned
+     # even if some field could not be fetched
+     field_metadatas: dict[FieldId, FieldComputedMetadata] = {
+         fid: fm for fid, fm in zip(hydrated_field_ids, await asyncio.gather(*fm_ops)) if fm is not None
+     }
+     extracted_texts: dict[FieldId, ExtractedText] = {
+         fid: et for fid, et in zip(hydrated_field_ids, await asyncio.gather(*et_ops)) if et is not None
+     }
+
+     def _get_paragraph_text(extracted_text: ExtractedText, pid: ParagraphId) -> str:
+         if pid.field_id.subfield_id:
+             text = extracted_text.split_text.get(pid.field_id.subfield_id) or ""
+         else:
+             text = extracted_text.text
+         return text[pid.paragraph_start : pid.paragraph_end]
+
+     for pid in retrieved_paragraphs_ids:
+         # Add the retrieved paragraph first
+         field_extracted_text = extracted_texts.get(pid.field_id, None)
+         if field_extracted_text is None:
+             continue
+         ptext = _get_paragraph_text(field_extracted_text, pid)
+         if ptext and pid.full() not in context:
+             context[pid.full()] = ptext
+
+         # Now add the neighbouring paragraphs
+         field_extracted_metadata = field_metadatas.get(pid.field_id, None)
+         if field_extracted_metadata is None:
+             continue
+
+         field_pids = [
+             ParagraphId(
+                 field_id=pid.field_id,
+                 paragraph_start=p.start,
+                 paragraph_end=p.end,
+             )
+             for p in field_extracted_metadata.metadata.paragraphs
+         ]
+         try:
+             index = field_pids.index(pid)
+         except ValueError:
+             continue
+
+         for neighbour_index in get_neighbouring_indices(
+             index=index,
+             before=strategy.before,
+             after=strategy.after,
+             field_pids=field_pids,
+         ):
+             if neighbour_index == index:
+                 # Already handled above
+                 continue
+             try:
+                 npid = field_pids[neighbour_index]
+             except IndexError:
+                 continue
+             if npid in retrieved_paragraphs_ids or npid.full() in context:
+                 # Already added
+                 continue
+             ptext = _get_paragraph_text(field_extracted_text, npid)
+             if not ptext:
+                 continue
+             context[npid.full()] = ptext
+             augmented_context.paragraphs[npid.full()] = AugmentedTextBlock(
+                 id=npid.full(),
+                 text=ptext,
+                 position=get_text_position(npid, neighbour_index, field_extracted_metadata),
+                 parent=pid.full(),
+                 augmentation_type=TextBlockAugmentationType.NEIGHBOURING_PARAGRAPHS,
+             )
+
+     metrics.set("neighbouring_paragraphs_ops", len(augmented_context.paragraphs))
+
+
+ def get_text_position(
+     paragraph_id: ParagraphId, index: int, field_metadata: FieldComputedMetadata
+ ) -> TextPosition | None:
+     if paragraph_id.field_id.subfield_id:
+         metadata = field_metadata.split_metadata[paragraph_id.field_id.subfield_id]
+     else:
+         metadata = field_metadata.metadata
+     try:
+         pmetadata = metadata.paragraphs[index]
+     except IndexError:
+         return None
+     page_number = None
+     if pmetadata.HasField("page"):
+         page_number = pmetadata.page.page
+     return TextPosition(
+         page_number=page_number,
+         index=index,
+         start=pmetadata.start,
+         end=pmetadata.end,
+         start_seconds=list(pmetadata.start_seconds),
+         end_seconds=list(pmetadata.end_seconds),
+     )
+
+
+ def get_neighbouring_indices(
+     index: int, before: int, after: int, field_pids: list[ParagraphId]
+ ) -> list[int]:
+     lb_index = max(0, index - before)
+     ub_index = min(len(field_pids), index + after + 1)
+     return list(range(lb_index, index)) + list(range(index + 1, ub_index))
+
+
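A worked example of the window arithmetic: the function returns neighbour indices only, clamped to the field bounds and excluding the match itself (integers stand in for the field's ParagraphIds, since only len() is used):

    field_pids = list(range(8))
    assert get_neighbouring_indices(index=5, before=2, after=2, field_pids=field_pids) == [3, 4, 6, 7]
    assert get_neighbouring_indices(index=0, before=2, after=1, field_pids=field_pids) == [1]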
+ async def conversation_prompt_context(
+     context: CappedPromptContext,
+     kbid: str,
+     ordered_paragraphs: list[FindParagraph],
+     strategy: ConversationalStrategy,
+     visual_llm: bool,
+     metrics: Metrics,
+     augmented_context: AugmentedContext,
+ ):
+     analyzed_fields: list[str] = []
+     ops = 0
+     async with get_driver().ro_transaction() as txn:
+         storage = await get_storage()
+         kb = KnowledgeBoxORM(txn, storage, kbid)
+         for paragraph in ordered_paragraphs:
+             if paragraph.id not in context:
+                 context[paragraph.id] = _clean_paragraph_text(paragraph)
+
+             # If the paragraph is a conversation and it matches semantically, we assume we
+             # have matched with the question, therefore try to include the answer to the
+             # context by pulling the next few messages of the conversation field
+             rid, field_type, field_id, mident = paragraph.id.split("/")[:4]
+             if field_type == "c" and paragraph.score_type in (
+                 SCORE_TYPE.VECTOR,
+                 SCORE_TYPE.BOTH,
+                 SCORE_TYPE.BM25,
+             ):
+                 field_unique_id = "-".join([rid, field_type, field_id])
+                 if field_unique_id in analyzed_fields:
+                     continue
+                 resource = await kb.get(rid)
+                 if resource is None:  # pragma: no cover
+                     continue
+
+                 field_obj: Conversation = await resource.get_field(
+                     field_id, FIELD_TYPE_STR_TO_PB["c"], load=True
+                 )  # type: ignore
+                 cmetadata = await field_obj.get_metadata()
+
+                 attachments: list[resources_pb2.FieldRef] = []
+                 if strategy.full:
+                     ops += 5
+                     extracted_text = await field_obj.get_extracted_text()
+                     for current_page in range(1, cmetadata.pages + 1):
+                         conv = await field_obj.db_get_value(current_page)
+
+                         for message in conv.messages:
+                             ident = message.ident
+                             if extracted_text is not None:
+                                 text = extracted_text.split_text.get(ident, message.content.text.strip())
+                             else:
+                                 text = message.content.text.strip()
+                             pid = f"{rid}/{field_type}/{field_id}/{ident}/0-{len(text)}"
+                             attachments.extend(message.content.attachments_fields)
+                             if pid in context:
+                                 continue
+                             context[pid] = text
+                             augmented_context.paragraphs[pid] = AugmentedTextBlock(
+                                 id=pid,
+                                 text=text,
+                                 parent=paragraph.id,
+                                 augmentation_type=TextBlockAugmentationType.CONVERSATION,
+                             )
+                 else:
+                     # Add first message
+                     extracted_text = await field_obj.get_extracted_text()
+                     first_page = await field_obj.db_get_value()
+                     if len(first_page.messages) > 0:
+                         message = first_page.messages[0]
+                         ident = message.ident
+                         if extracted_text is not None:
+                             text = extracted_text.split_text.get(ident, message.content.text.strip())
+                         else:
+                             text = message.content.text.strip()
+                         attachments.extend(message.content.attachments_fields)
+                         pid = f"{rid}/{field_type}/{field_id}/{ident}/0-{len(text)}"
+                         if pid in context:
+                             continue
+                         context[pid] = text
+                         augmented_context.paragraphs[pid] = AugmentedTextBlock(
+                             id=pid,
+                             text=text,
+                             parent=paragraph.id,
+                             augmentation_type=TextBlockAugmentationType.CONVERSATION,
+                         )
+
+                     messages: Deque[resources_pb2.Message] = deque(maxlen=strategy.max_messages)
+
+                     pending = -1
+                     for page in range(1, cmetadata.pages + 1):
+                         # Collect the messages within the window requested by the user around the matched paragraph
+                         conv = await field_obj.db_get_value(page)
+                         for message in conv.messages:
+                             messages.append(message)
+                             if pending > 0:
+                                 pending -= 1
+                             if message.ident == mident:
+                                 pending = (strategy.max_messages - 1) // 2
+                             if pending == 0:
+                                 break
+                         if pending == 0:
+                             break
+
+                     for message in messages:
+                         ops += 1
+                         text = message.content.text.strip()
+                         attachments.extend(message.content.attachments_fields)
+                         pid = f"{rid}/{field_type}/{field_id}/{message.ident}/0-{len(message.content.text)}"
+                         if pid in context:
+                             continue
+                         context[pid] = text
+                         augmented_context.paragraphs[pid] = AugmentedTextBlock(
+                             id=pid,
+                             text=text,
+                             parent=paragraph.id,
+                             augmentation_type=TextBlockAugmentationType.CONVERSATION,
+                         )
+
+                 if strategy.attachments_text:
+                     # Add the attachments' extracted text to the context
+                     for attachment in attachments:
+                         ops += 1
+                         field: File = await resource.get_field(
+                             attachment.field_id, attachment.field_type, load=True
+                         )  # type: ignore
+                         extracted_text = await field.get_extracted_text()
+                         if extracted_text is not None:
+                             attachment_field_type = FIELD_TYPE_PB_TO_STR[attachment.field_type]
+                             pid = f"{rid}/{attachment_field_type}/{attachment.field_id}/0-{len(extracted_text.text)}"
+                             if pid in context:
+                                 continue
+                             text = f"Attachment {attachment.field_id}: {extracted_text.text}\n\n"
+                             context[pid] = text
+                             augmented_context.paragraphs[pid] = AugmentedTextBlock(
+                                 id=pid,
+                                 text=text,
+                                 parent=paragraph.id,
+                                 augmentation_type=TextBlockAugmentationType.CONVERSATION,
+                             )
+
+                 if strategy.attachments_images and visual_llm:
+                     # Add the attachment images to the context when a visual LLM is enabled
+                     for attachment in attachments:
+                         ops += 1
+                         file_field: File = await resource.get_field(
+                             attachment.field_id, attachment.field_type, load=True
+                         )  # type: ignore
+                         image = await get_file_thumbnail_image(file_field)
+                         if image is not None:
+                             pid = f"{rid}/f/{attachment.field_id}/0-0"
+                             context.images[pid] = image
+
+                 analyzed_fields.append(field_unique_id)
+     metrics.set("conversation_ops", ops)
+
+
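The message-window selection above can be sketched standalone: a rolling deque of max_messages keeps the tail of the conversation, and after the matched ident another (max_messages - 1) // 2 messages are consumed so the match lands roughly mid-window (idents are plain integers here):

    from collections import deque

    max_messages = 5
    window: deque[int] = deque(maxlen=max_messages)
    pending = -1
    for ident in range(10):
        window.append(ident)
        if pending > 0:
            pending -= 1
        if ident == 6:  # the matched message
            pending = (max_messages - 1) // 2
        if pending == 0:
            break
    assert list(window) == [4, 5, 6, 7, 8]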
+ async def hierarchy_prompt_context(
+     context: CappedPromptContext,
+     kbid: str,
+     ordered_paragraphs: list[FindParagraph],
+     strategy: HierarchyResourceStrategy,
+     metrics: Metrics,
+     augmented_context: AugmentedContext,
+ ) -> None:
+     """
+     This function will get the paragraph texts (possibly with extra characters, if extra_characters > 0) and then
+     craft a context with all paragraphs of the same resource grouped together. Moreover, on each group of paragraphs,
+     it includes the resource title and summary so that the LLM can have a better understanding of the context.
+     """
+     paragraphs_extra_characters = max(strategy.count, 0)
+     # Make a copy of the ordered paragraphs to avoid modifying the original list, which is returned
+     # in the response to the user
+     ordered_paragraphs_copy = copy.deepcopy(ordered_paragraphs)
+     resources: dict[str, ExtraCharsParagraph] = {}
+
+     # Iterate paragraphs to get extended text
+     for paragraph in ordered_paragraphs_copy:
+         paragraph_id = ParagraphId.from_string(paragraph.id)
+         extended_paragraph_text = paragraph.text
+         if paragraphs_extra_characters > 0:
+             extended_paragraph_id = ParagraphId(
+                 field_id=paragraph_id.field_id,
+                 paragraph_start=paragraph_id.paragraph_start,
+                 paragraph_end=paragraph_id.paragraph_end + paragraphs_extra_characters,
+             )
+             extended_paragraph_text = await get_paragraph_text(
+                 kbid=kbid,
+                 paragraph_id=extended_paragraph_id,
+                 log_on_missing_field=True,
+             )
+         rid = paragraph_id.rid
+         if rid not in resources:
+             # Get the title and the summary of the resource
+             title_text = await get_paragraph_text(
+                 kbid=kbid,
+                 paragraph_id=ParagraphId(
+                     field_id=FieldId(
+                         rid=rid,
+                         type="a",
+                         key="title",
+                     ),
+                     paragraph_start=0,
+                     paragraph_end=500,
+                 ),
+                 log_on_missing_field=False,
+             )
+             summary_text = await get_paragraph_text(
+                 kbid=kbid,
+                 paragraph_id=ParagraphId(
+                     field_id=FieldId(
+                         rid=rid,
+                         type="a",
+                         key="summary",
+                     ),
+                     paragraph_start=0,
+                     paragraph_end=1000,
+                 ),
+                 log_on_missing_field=False,
+             )
+             resources[rid] = ExtraCharsParagraph(
+                 title=title_text,
+                 summary=summary_text,
+                 paragraphs=[(paragraph, extended_paragraph_text)],
+             )
+         else:
+             resources[rid].paragraphs.append((paragraph, extended_paragraph_text))
+
+     metrics.set("hierarchy_ops", len(resources))
+     augmented_paragraphs = set()
+
+     # Modify the first paragraph of each resource to include the title and summary of the resource, as well as the
+     # extended paragraph text of all the paragraphs in the resource.
+     for values in resources.values():
+         title_text = values.title
+         summary_text = values.summary
+         first_paragraph = None
+         text_with_hierarchy = ""
+         for paragraph, extended_paragraph_text in values.paragraphs:
+             if first_paragraph is None:
+                 first_paragraph = paragraph
+             text_with_hierarchy += "\n EXTRACTED BLOCK: \n " + extended_paragraph_text + " \n\n "
+             # All paragraphs of the resource are cleared except the first one, which will be the
+             # one containing the whole hierarchy information
+             paragraph.text = ""
+
+         if first_paragraph is not None:
+             # The first paragraph is the only one holding the hierarchy information
+             first_paragraph.text = f"DOCUMENT: {title_text} \n SUMMARY: {summary_text} \n RESOURCE CONTENT: {text_with_hierarchy}"
+             augmented_paragraphs.add(first_paragraph.id)
+
+     # Now that the paragraphs have been modified, we can add them to the context
+     for paragraph in ordered_paragraphs_copy:
+         if paragraph.text == "":
+             # Skip paragraphs that were cleared in the hierarchy expansion
+             continue
+         paragraph_text = _clean_paragraph_text(paragraph)
+         context[paragraph.id] = paragraph_text
+         if paragraph.id in augmented_paragraphs:
+             pid = ParagraphId.from_string(paragraph.id)
+             augmented_context.paragraphs[pid.full()] = AugmentedTextBlock(
+                 id=pid.full(), text=paragraph_text, augmentation_type=TextBlockAugmentationType.HIERARCHY
+             )
+     return
+
+
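With toy values, the hierarchy text assembled above for the first paragraph of a resource looks like this (a sketch that mirrors the f-strings in the function, not a transcript of real output):

    title_text, summary_text = "My doc", "A short summary"
    text_with_hierarchy = ""
    for extended_paragraph_text in ["block one", "block two"]:
        text_with_hierarchy += "\n EXTRACTED BLOCK: \n " + extended_paragraph_text + " \n\n "
    first_text = f"DOCUMENT: {title_text} \n SUMMARY: {summary_text} \n RESOURCE CONTENT: {text_with_hierarchy}"
    print(first_text)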
1063
+ class PromptContextBuilder:
1064
+ """
1065
+ Builds the context for the LLM prompt.
1066
+ """
1067
+
1068
+ def __init__(
1069
+ self,
1070
+ kbid: str,
1071
+ ordered_paragraphs: list[FindParagraph],
1072
+ resource: str | None = None,
1073
+ user_context: list[str] | None = None,
1074
+ user_image_context: list[Image] | None = None,
1075
+ strategies: Sequence[RagStrategy] | None = None,
1076
+ image_strategies: Sequence[ImageRagStrategy] | None = None,
1077
+ max_context_characters: int | None = None,
1078
+ visual_llm: bool = False,
1079
+ query_image: Image | None = None,
1080
+ metrics: Metrics = Metrics("prompt_context_builder"),
1081
+ ):
1082
+ self.kbid = kbid
1083
+ self.ordered_paragraphs = ordered_paragraphs
1084
+ self.resource = resource
1085
+ self.user_context = user_context
1086
+ self.user_image_context = user_image_context
1087
+ self.strategies = strategies
1088
+ self.image_strategies = image_strategies
1089
+ self.max_context_characters = max_context_characters
1090
+ self.visual_llm = visual_llm
1091
+ self.metrics = metrics
1092
+ self.query_image = query_image
1093
+ self.augmented_context = AugmentedContext(paragraphs={}, fields={})
1094
+
1095
+ def prepend_user_context(self, context: CappedPromptContext):
1096
+ # Chat extra context passed by the user is the most important, therefore
1097
+ # it is added first, followed by the found text blocks in order of relevance
1098
+ for i, text_block in enumerate(self.user_context or []):
1099
+ context[f"USER_CONTEXT_{i}"] = text_block
1100
+ # Add the query image as part of the image context
1101
+ if self.query_image is not None:
1102
+ context.images["QUERY_IMAGE"] = self.query_image
1103
+ else:
1104
+ for i, image in enumerate(self.user_image_context or []):
1105
+ context.images[f"USER_IMAGE_CONTEXT_{i}"] = image
1106
+
1107
+ async def build(
1108
+ self,
1109
+ ) -> tuple[PromptContext, PromptContextOrder, PromptContextImages, AugmentedContext]:
1110
+ ccontext = CappedPromptContext(max_size=self.max_context_characters)
1111
+ self.prepend_user_context(ccontext)
1112
+ await self._build_context(ccontext)
1113
+ if self.visual_llm and not self.query_image:
1114
+ await self._build_context_images(ccontext)
1115
+ context = ccontext.cap()
1116
+ context_images = ccontext.images
1117
+ context_order = {text_block_id: order for order, text_block_id in enumerate(context.keys())}
1118
+ return context, context_order, context_images, self.augmented_context
1119
+
1120
+    async def _build_context_images(self, context: CappedPromptContext) -> None:
+        ops = 0
+        if self.image_strategies is None or len(self.image_strategies) == 0:
+            # Nothing to do
+            return
+        page_image_strategy: PageImageStrategy | None = None
+        max_page_images = 5
+        table_image_strategy: TableImageStrategy | None = None
+        paragraph_image_strategy: ParagraphImageStrategy | None = None
+        for strategy in self.image_strategies:
+            if strategy.name == ImageRagStrategyName.PAGE_IMAGE:
+                if page_image_strategy is None:
+                    page_image_strategy = cast(PageImageStrategy, strategy)
+                    if page_image_strategy.count is not None:
+                        max_page_images = page_image_strategy.count
+            elif strategy.name == ImageRagStrategyName.TABLES:
+                if table_image_strategy is None:
+                    table_image_strategy = cast(TableImageStrategy, strategy)
+            elif strategy.name == ImageRagStrategyName.PARAGRAPH_IMAGE:
+                if paragraph_image_strategy is None:
+                    paragraph_image_strategy = cast(ParagraphImageStrategy, strategy)
+            else:  # pragma: no cover
+                logger.warning(
+                    "Unknown image strategy",
+                    extra={"strategy": strategy.name, "kbid": self.kbid},
+                )
+        page_images_added = 0
+        for paragraph in self.ordered_paragraphs:
+            pid = ParagraphId.from_string(paragraph.id)
+            paragraph_page_number = get_paragraph_page_number(paragraph)
+            if (
+                page_image_strategy is not None
+                and page_images_added < max_page_images
+                and paragraph_page_number is not None
+            ):
+                # page_image_id: rid/f/myfield/0
+                page_image_id = "/".join([pid.field_id.full(), str(paragraph_page_number)])
+                if page_image_id not in context.images:
+                    image = await get_page_image(self.kbid, pid, paragraph_page_number)
+                    if image is not None:
+                        ops += 1
+                        context.images[page_image_id] = image
+                        page_images_added += 1
+                    else:
+                        logger.warning(
+                            "Could not retrieve image for paragraph from storage",
+                            extra={
+                                "kbid": self.kbid,
+                                "paragraph": pid.full(),
+                                "page_number": paragraph_page_number,
+                            },
+                        )
+
+            add_table = table_image_strategy is not None and paragraph.is_a_table
+            add_paragraph = paragraph_image_strategy is not None and not paragraph.is_a_table
+            if (add_table or add_paragraph) and (
+                paragraph.reference is not None and paragraph.reference != ""
+            ):
+                pimage = await get_paragraph_image(self.kbid, pid, paragraph.reference)
+                if pimage is not None:
+                    ops += 1
+                    context.images[paragraph.id] = pimage
+                else:
+                    logger.warning(
+                        "Could not retrieve image for paragraph from storage",
+                        extra={
+                            "kbid": self.kbid,
+                            "paragraph": pid.full(),
+                            "reference": paragraph.reference,
+                        },
+                    )
+        self.metrics.set("image_ops", ops)
+
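For reference, a sketch of how the page image id commented above (`rid/f/myfield/0`) is composed; the paragraph id string is fabricated, in the `rid/field_type/field/start-end` format this module parses:

```python
# Illustration of the page_image_id composition used above; the id string
# is made up, ParagraphId comes from nucliadb.common.ids.
from nucliadb.common.ids import ParagraphId

pid = ParagraphId.from_string("rid/f/myfield/0-20")
page_image_id = "/".join([pid.field_id.full(), str(0)])
assert page_image_id == "rid/f/myfield/0"
```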
+    async def _build_context(self, context: CappedPromptContext) -> None:
+        if self.strategies is None or len(self.strategies) == 0:
+            # When no strategy is specified, use the default one
+            await default_prompt_context(context, self.kbid, self.ordered_paragraphs)
+            return
+        else:
+            # Add the paragraphs to the context and then apply the strategies
+            for paragraph in self.ordered_paragraphs:
+                context[paragraph.id] = _clean_paragraph_text(paragraph)
+
+        strategies_not_handled_here = [
+            RagStrategyName.PREQUERIES,
+            RagStrategyName.GRAPH,
+        ]
+
+        full_resource: FullResourceStrategy | None = None
+        hierarchy: HierarchyResourceStrategy | None = None
+        neighbouring_paragraphs: NeighbouringParagraphsStrategy | None = None
+        field_extension: FieldExtensionStrategy | None = None
+        metadata_extension: MetadataExtensionStrategy | None = None
+        conversational_strategy: ConversationalStrategy | None = None
+        for strategy in self.strategies:
+            if strategy.name == RagStrategyName.FIELD_EXTENSION:
+                field_extension = cast(FieldExtensionStrategy, strategy)
+            elif strategy.name == RagStrategyName.CONVERSATION:
+                conversational_strategy = cast(ConversationalStrategy, strategy)
+            elif strategy.name == RagStrategyName.FULL_RESOURCE:
+                full_resource = cast(FullResourceStrategy, strategy)
+                if self.resource:  # pragma: no cover
+                    # When the retrieval is scoped to a specific resource,
+                    # the full resource strategy only includes that resource
+                    full_resource.count = 1
+            elif strategy.name == RagStrategyName.HIERARCHY:
+                hierarchy = cast(HierarchyResourceStrategy, strategy)
+            elif strategy.name == RagStrategyName.NEIGHBOURING_PARAGRAPHS:
+                neighbouring_paragraphs = cast(NeighbouringParagraphsStrategy, strategy)
+            elif strategy.name == RagStrategyName.METADATA_EXTENSION:
+                metadata_extension = cast(MetadataExtensionStrategy, strategy)
+            elif strategy.name not in strategies_not_handled_here:  # pragma: no cover
+                # Prequeries and graph are not handled here
+                logger.warning(
+                    "Unknown rag strategy",
+                    extra={"strategy": strategy.name, "kbid": self.kbid},
+                )
+
+        if full_resource:
+            # When full resource is enabled, only metadata extension is allowed.
+            await full_resource_prompt_context(
+                context,
+                self.kbid,
+                self.ordered_paragraphs,
+                self.resource,
+                full_resource,
+                self.metrics,
+                self.augmented_context,
+            )
+            if metadata_extension:
+                await extend_prompt_context_with_metadata(
+                    context,
+                    self.kbid,
+                    metadata_extension,
+                    self.metrics,
+                    self.augmented_context,
+                )
+            return
+
+        if hierarchy:
+            await hierarchy_prompt_context(
+                context,
+                self.kbid,
+                self.ordered_paragraphs,
+                hierarchy,
+                self.metrics,
+                self.augmented_context,
+            )
+        if neighbouring_paragraphs:
+            await neighbouring_paragraphs_prompt_context(
+                context,
+                self.kbid,
+                self.ordered_paragraphs,
+                neighbouring_paragraphs,
+                self.metrics,
+                self.augmented_context,
+            )
+        if field_extension:
+            await field_extension_prompt_context(
+                context,
+                self.kbid,
+                self.ordered_paragraphs,
+                field_extension,
+                self.metrics,
+                self.augmented_context,
+            )
+        if conversational_strategy:
+            await conversation_prompt_context(
+                context,
+                self.kbid,
+                self.ordered_paragraphs,
+                conversational_strategy,
+                self.visual_llm,
+                self.metrics,
+                self.augmented_context,
+            )
+        if metadata_extension:
+            await extend_prompt_context_with_metadata(
+                context,
+                self.kbid,
+                metadata_extension,
+                self.metrics,
+                self.augmented_context,
+            )
+
+
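Note the early `return` in the `full_resource` branch: only `metadata_extension` can be combined with it, and any other configured strategy is silently skipped. A hedged configuration sketch, written as the kind of request payload these strategies are parsed from (the strategy names match the `RagStrategyName` values handled above; the exact payload schema may differ between versions):

```python
# Hypothetical rag_strategies payload; schema details are assumptions.
rag_strategies = [
    {"name": "full_resource", "count": 2},  # terminal: triggers the early return
    {"name": "metadata_extension", "types": ["origin"]},  # still applied alongside
    # {"name": "hierarchy"} here would be ignored once full_resource runs
]
```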
+def get_paragraph_page_number(paragraph: FindParagraph) -> int | None:
+    if not paragraph.page_with_visual:
+        return None
+    if paragraph.position is None:
+        return None
+    return paragraph.position.page_number
+
+
+@dataclass
+class ExtraCharsParagraph:
+    title: str
+    summary: str
+    paragraphs: list[tuple[FindParagraph, str]]
+
+
+def _clean_paragraph_text(paragraph: FindParagraph) -> str:
+    text = paragraph.text.strip()
+    # Do not send highlight marks in the prompt context
+    text = text.replace("<mark>", "").replace("</mark>", "")
+    return text
+
+
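A quick behavioral sketch of `_clean_paragraph_text` on a highlighted snippet (the input string is fabricated):

```python
# Fabricated input showing that highlight marks never reach the prompt.
text = "  The <mark>total</mark> is 42.  ".strip()
text = text.replace("<mark>", "").replace("</mark>", "")
assert text == "The total is 42."
```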
+# Copied from hydrator/__init__.py, which has been refactored and removed
+
+
+hydrator_observer = Observer("hydrator", labels={"type": ""})
+
+
+@hydrator_observer.wrap({"type": "resource_text"})
+async def hydrate_resource_text(
+    kbid: str, rid: str, *, max_concurrent_tasks: int
+) -> list[tuple[FieldId, str]]:
+    resource = await cache.get_resource(kbid, rid)
+    if resource is None:  # pragma: no cover
+        return []
+
+    # Schedule the extraction of the text of each field in the resource
+    async with get_driver().ro_transaction() as txn:
+        resource.txn = txn
+        runner = ConcurrentRunner(max_tasks=max_concurrent_tasks)
+        for field_type, field_key in await resource.get_fields(force=True):
+            field_id = FieldId.from_pb(rid, field_type, field_key)
+            runner.schedule(hydrate_field_text(kbid, field_id))
+
+        # Include the summary as well
+        runner.schedule(hydrate_field_text(kbid, FieldId(rid=rid, type="a", key="summary")))
+
+        # Wait for the results
+        field_extracted_texts = await runner.wait()
+
+    return [text for text in field_extracted_texts if text is not None]
+
+
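A hypothetical caller, to show the bounded-concurrency contract: at most `max_concurrent_tasks` extracted-text fetches run at once, and fields with no extracted text simply drop out of the result (the kbid/rid values are placeholders):

```python
# Hypothetical call; "kbid"/"rid" are placeholders, not real identifiers.
texts = await hydrate_resource_text("kbid", "rid", max_concurrent_tasks=4)
for field_id, extracted_text in texts:
    print(field_id.full(), len(extracted_text))
```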
+@hydrator_observer.wrap({"type": "field_text"})
+async def hydrate_field_text(
+    kbid: str,
+    field_id: FieldId,
+) -> tuple[FieldId, str] | None:
+    field = await cache.get_field(kbid, field_id)
+    if field is None:  # pragma: no cover
+        return None
+
+    extracted_text_pb = await cache.get_field_extracted_text(field)
+    if extracted_text_pb is None:  # pragma: no cover
+        return None
+
+    if field_id.subfield_id:
+        return field_id, extracted_text_pb.split_text[field_id.subfield_id]
+    else:
+        return field_id, extracted_text_pb.text
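Finally, a sketch of the split-field path: conversation-style fields keep per-message text in `split_text`, keyed by `subfield_id` (the ids below are fabricated; `type="c"` is assumed to be the conversation field type letter, consistent with the `type="a"` summary field used above):

```python
# Fabricated ids; FieldId comes from nucliadb.common.ids.
fid = FieldId(rid="rid", type="c", key="chat", subfield_id="msg-1")
result = await hydrate_field_text("kbid", fid)
if result is not None:
    _, message_text = result  # text of just that one conversation message
```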