nucliadb 6.7.2.post4874__py3-none-any.whl → 6.10.0.post5705__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (246)
  1. migrations/0023_backfill_pg_catalog.py +8 -4
  2. migrations/0028_extracted_vectors_reference.py +1 -1
  3. migrations/0029_backfill_field_status.py +3 -4
  4. migrations/0032_remove_old_relations.py +2 -3
  5. migrations/0038_backfill_catalog_field_labels.py +8 -4
  6. migrations/0039_backfill_converation_splits_metadata.py +106 -0
  7. migrations/0040_migrate_search_configurations.py +79 -0
  8. migrations/0041_reindex_conversations.py +137 -0
  9. migrations/pg/0010_shards_index.py +34 -0
  10. nucliadb/search/api/v1/resource/utils.py → migrations/pg/0011_catalog_statistics.py +5 -6
  11. migrations/pg/0012_catalog_statistics_undo.py +26 -0
  12. nucliadb/backups/create.py +2 -15
  13. nucliadb/backups/restore.py +4 -15
  14. nucliadb/backups/tasks.py +4 -1
  15. nucliadb/common/back_pressure/cache.py +2 -3
  16. nucliadb/common/back_pressure/materializer.py +7 -13
  17. nucliadb/common/back_pressure/settings.py +6 -6
  18. nucliadb/common/back_pressure/utils.py +1 -0
  19. nucliadb/common/cache.py +9 -9
  20. nucliadb/common/catalog/__init__.py +79 -0
  21. nucliadb/common/catalog/dummy.py +36 -0
  22. nucliadb/common/catalog/interface.py +85 -0
  23. nucliadb/{search/search/pgcatalog.py → common/catalog/pg.py} +330 -232
  24. nucliadb/common/catalog/utils.py +56 -0
  25. nucliadb/common/cluster/manager.py +8 -23
  26. nucliadb/common/cluster/rebalance.py +484 -112
  27. nucliadb/common/cluster/rollover.py +36 -9
  28. nucliadb/common/cluster/settings.py +4 -9
  29. nucliadb/common/cluster/utils.py +34 -8
  30. nucliadb/common/context/__init__.py +7 -8
  31. nucliadb/common/context/fastapi.py +1 -2
  32. nucliadb/common/datamanagers/__init__.py +2 -4
  33. nucliadb/common/datamanagers/atomic.py +9 -2
  34. nucliadb/common/datamanagers/cluster.py +1 -2
  35. nucliadb/common/datamanagers/fields.py +3 -4
  36. nucliadb/common/datamanagers/kb.py +6 -6
  37. nucliadb/common/datamanagers/labels.py +2 -3
  38. nucliadb/common/datamanagers/resources.py +10 -33
  39. nucliadb/common/datamanagers/rollover.py +5 -7
  40. nucliadb/common/datamanagers/search_configurations.py +1 -2
  41. nucliadb/common/datamanagers/synonyms.py +1 -2
  42. nucliadb/common/datamanagers/utils.py +4 -4
  43. nucliadb/common/datamanagers/vectorsets.py +4 -4
  44. nucliadb/common/external_index_providers/base.py +32 -5
  45. nucliadb/common/external_index_providers/manager.py +5 -34
  46. nucliadb/common/external_index_providers/settings.py +1 -27
  47. nucliadb/common/filter_expression.py +129 -41
  48. nucliadb/common/http_clients/exceptions.py +8 -0
  49. nucliadb/common/http_clients/processing.py +16 -23
  50. nucliadb/common/http_clients/utils.py +3 -0
  51. nucliadb/common/ids.py +82 -58
  52. nucliadb/common/locking.py +1 -2
  53. nucliadb/common/maindb/driver.py +9 -8
  54. nucliadb/common/maindb/local.py +5 -5
  55. nucliadb/common/maindb/pg.py +9 -8
  56. nucliadb/common/nidx.py +22 -5
  57. nucliadb/common/vector_index_config.py +1 -1
  58. nucliadb/export_import/datamanager.py +4 -3
  59. nucliadb/export_import/exporter.py +11 -19
  60. nucliadb/export_import/importer.py +13 -6
  61. nucliadb/export_import/tasks.py +2 -0
  62. nucliadb/export_import/utils.py +6 -18
  63. nucliadb/health.py +2 -2
  64. nucliadb/ingest/app.py +8 -8
  65. nucliadb/ingest/consumer/consumer.py +8 -10
  66. nucliadb/ingest/consumer/pull.py +10 -8
  67. nucliadb/ingest/consumer/service.py +5 -30
  68. nucliadb/ingest/consumer/shard_creator.py +16 -5
  69. nucliadb/ingest/consumer/utils.py +1 -1
  70. nucliadb/ingest/fields/base.py +37 -49
  71. nucliadb/ingest/fields/conversation.py +55 -9
  72. nucliadb/ingest/fields/exceptions.py +1 -2
  73. nucliadb/ingest/fields/file.py +22 -8
  74. nucliadb/ingest/fields/link.py +7 -7
  75. nucliadb/ingest/fields/text.py +2 -3
  76. nucliadb/ingest/orm/brain_v2.py +89 -57
  77. nucliadb/ingest/orm/broker_message.py +2 -4
  78. nucliadb/ingest/orm/entities.py +10 -209
  79. nucliadb/ingest/orm/index_message.py +128 -113
  80. nucliadb/ingest/orm/knowledgebox.py +91 -59
  81. nucliadb/ingest/orm/processor/auditing.py +1 -3
  82. nucliadb/ingest/orm/processor/data_augmentation.py +1 -2
  83. nucliadb/ingest/orm/processor/processor.py +98 -153
  84. nucliadb/ingest/orm/processor/sequence_manager.py +1 -2
  85. nucliadb/ingest/orm/resource.py +82 -71
  86. nucliadb/ingest/orm/utils.py +1 -1
  87. nucliadb/ingest/partitions.py +12 -1
  88. nucliadb/ingest/processing.py +17 -17
  89. nucliadb/ingest/serialize.py +202 -145
  90. nucliadb/ingest/service/writer.py +15 -114
  91. nucliadb/ingest/settings.py +36 -15
  92. nucliadb/ingest/utils.py +1 -2
  93. nucliadb/learning_proxy.py +23 -26
  94. nucliadb/metrics_exporter.py +20 -6
  95. nucliadb/middleware/__init__.py +82 -1
  96. nucliadb/migrator/datamanager.py +4 -11
  97. nucliadb/migrator/migrator.py +1 -2
  98. nucliadb/migrator/models.py +1 -2
  99. nucliadb/migrator/settings.py +1 -2
  100. nucliadb/models/internal/augment.py +614 -0
  101. nucliadb/models/internal/processing.py +19 -19
  102. nucliadb/openapi.py +2 -2
  103. nucliadb/purge/__init__.py +3 -8
  104. nucliadb/purge/orphan_shards.py +1 -2
  105. nucliadb/reader/__init__.py +5 -0
  106. nucliadb/reader/api/models.py +6 -13
  107. nucliadb/reader/api/v1/download.py +59 -38
  108. nucliadb/reader/api/v1/export_import.py +4 -4
  109. nucliadb/reader/api/v1/knowledgebox.py +37 -9
  110. nucliadb/reader/api/v1/learning_config.py +33 -14
  111. nucliadb/reader/api/v1/resource.py +61 -9
  112. nucliadb/reader/api/v1/services.py +18 -14
  113. nucliadb/reader/app.py +3 -1
  114. nucliadb/reader/reader/notifications.py +1 -2
  115. nucliadb/search/api/v1/__init__.py +3 -0
  116. nucliadb/search/api/v1/ask.py +3 -4
  117. nucliadb/search/api/v1/augment.py +585 -0
  118. nucliadb/search/api/v1/catalog.py +15 -19
  119. nucliadb/search/api/v1/find.py +16 -22
  120. nucliadb/search/api/v1/hydrate.py +328 -0
  121. nucliadb/search/api/v1/knowledgebox.py +1 -2
  122. nucliadb/search/api/v1/predict_proxy.py +1 -2
  123. nucliadb/search/api/v1/resource/ask.py +28 -8
  124. nucliadb/search/api/v1/resource/ingestion_agents.py +5 -6
  125. nucliadb/search/api/v1/resource/search.py +9 -11
  126. nucliadb/search/api/v1/retrieve.py +130 -0
  127. nucliadb/search/api/v1/search.py +28 -32
  128. nucliadb/search/api/v1/suggest.py +11 -14
  129. nucliadb/search/api/v1/summarize.py +1 -2
  130. nucliadb/search/api/v1/utils.py +2 -2
  131. nucliadb/search/app.py +3 -2
  132. nucliadb/search/augmentor/__init__.py +21 -0
  133. nucliadb/search/augmentor/augmentor.py +232 -0
  134. nucliadb/search/augmentor/fields.py +704 -0
  135. nucliadb/search/augmentor/metrics.py +24 -0
  136. nucliadb/search/augmentor/paragraphs.py +334 -0
  137. nucliadb/search/augmentor/resources.py +238 -0
  138. nucliadb/search/augmentor/utils.py +33 -0
  139. nucliadb/search/lifecycle.py +3 -1
  140. nucliadb/search/predict.py +33 -19
  141. nucliadb/search/predict_models.py +8 -9
  142. nucliadb/search/requesters/utils.py +11 -10
  143. nucliadb/search/search/cache.py +19 -42
  144. nucliadb/search/search/chat/ask.py +131 -59
  145. nucliadb/search/search/chat/exceptions.py +3 -5
  146. nucliadb/search/search/chat/fetcher.py +201 -0
  147. nucliadb/search/search/chat/images.py +6 -4
  148. nucliadb/search/search/chat/old_prompt.py +1375 -0
  149. nucliadb/search/search/chat/parser.py +510 -0
  150. nucliadb/search/search/chat/prompt.py +563 -615
  151. nucliadb/search/search/chat/query.py +453 -32
  152. nucliadb/search/search/chat/rpc.py +85 -0
  153. nucliadb/search/search/fetch.py +3 -4
  154. nucliadb/search/search/filters.py +8 -11
  155. nucliadb/search/search/find.py +33 -31
  156. nucliadb/search/search/find_merge.py +124 -331
  157. nucliadb/search/search/graph_strategy.py +14 -12
  158. nucliadb/search/search/hydrator/__init__.py +49 -0
  159. nucliadb/search/search/hydrator/fields.py +217 -0
  160. nucliadb/search/search/hydrator/images.py +130 -0
  161. nucliadb/search/search/hydrator/paragraphs.py +323 -0
  162. nucliadb/search/search/hydrator/resources.py +60 -0
  163. nucliadb/search/search/ingestion_agents.py +5 -5
  164. nucliadb/search/search/merge.py +90 -94
  165. nucliadb/search/search/metrics.py +24 -7
  166. nucliadb/search/search/paragraphs.py +7 -9
  167. nucliadb/search/search/predict_proxy.py +44 -18
  168. nucliadb/search/search/query.py +14 -86
  169. nucliadb/search/search/query_parser/fetcher.py +51 -82
  170. nucliadb/search/search/query_parser/models.py +19 -48
  171. nucliadb/search/search/query_parser/old_filters.py +20 -19
  172. nucliadb/search/search/query_parser/parsers/ask.py +5 -6
  173. nucliadb/search/search/query_parser/parsers/catalog.py +7 -11
  174. nucliadb/search/search/query_parser/parsers/common.py +21 -13
  175. nucliadb/search/search/query_parser/parsers/find.py +6 -29
  176. nucliadb/search/search/query_parser/parsers/graph.py +18 -28
  177. nucliadb/search/search/query_parser/parsers/retrieve.py +207 -0
  178. nucliadb/search/search/query_parser/parsers/search.py +15 -56
  179. nucliadb/search/search/query_parser/parsers/unit_retrieval.py +8 -29
  180. nucliadb/search/search/rank_fusion.py +18 -13
  181. nucliadb/search/search/rerankers.py +6 -7
  182. nucliadb/search/search/retrieval.py +300 -0
  183. nucliadb/search/search/summarize.py +5 -6
  184. nucliadb/search/search/utils.py +3 -4
  185. nucliadb/search/settings.py +1 -2
  186. nucliadb/standalone/api_router.py +1 -1
  187. nucliadb/standalone/app.py +4 -3
  188. nucliadb/standalone/auth.py +5 -6
  189. nucliadb/standalone/lifecycle.py +2 -2
  190. nucliadb/standalone/run.py +5 -4
  191. nucliadb/standalone/settings.py +5 -6
  192. nucliadb/standalone/versions.py +3 -4
  193. nucliadb/tasks/consumer.py +13 -8
  194. nucliadb/tasks/models.py +2 -1
  195. nucliadb/tasks/producer.py +3 -3
  196. nucliadb/tasks/retries.py +8 -7
  197. nucliadb/train/api/utils.py +1 -3
  198. nucliadb/train/api/v1/shards.py +1 -2
  199. nucliadb/train/api/v1/trainset.py +1 -2
  200. nucliadb/train/app.py +1 -1
  201. nucliadb/train/generator.py +4 -4
  202. nucliadb/train/generators/field_classifier.py +2 -2
  203. nucliadb/train/generators/field_streaming.py +6 -6
  204. nucliadb/train/generators/image_classifier.py +2 -2
  205. nucliadb/train/generators/paragraph_classifier.py +2 -2
  206. nucliadb/train/generators/paragraph_streaming.py +2 -2
  207. nucliadb/train/generators/question_answer_streaming.py +2 -2
  208. nucliadb/train/generators/sentence_classifier.py +4 -10
  209. nucliadb/train/generators/token_classifier.py +3 -2
  210. nucliadb/train/generators/utils.py +6 -5
  211. nucliadb/train/nodes.py +3 -3
  212. nucliadb/train/resource.py +6 -8
  213. nucliadb/train/settings.py +3 -4
  214. nucliadb/train/types.py +11 -11
  215. nucliadb/train/upload.py +3 -2
  216. nucliadb/train/uploader.py +1 -2
  217. nucliadb/train/utils.py +1 -2
  218. nucliadb/writer/api/v1/export_import.py +4 -1
  219. nucliadb/writer/api/v1/field.py +15 -14
  220. nucliadb/writer/api/v1/knowledgebox.py +18 -56
  221. nucliadb/writer/api/v1/learning_config.py +5 -4
  222. nucliadb/writer/api/v1/resource.py +9 -20
  223. nucliadb/writer/api/v1/services.py +10 -132
  224. nucliadb/writer/api/v1/upload.py +73 -72
  225. nucliadb/writer/app.py +8 -2
  226. nucliadb/writer/resource/basic.py +12 -15
  227. nucliadb/writer/resource/field.py +43 -5
  228. nucliadb/writer/resource/origin.py +7 -0
  229. nucliadb/writer/settings.py +2 -3
  230. nucliadb/writer/tus/__init__.py +2 -3
  231. nucliadb/writer/tus/azure.py +5 -7
  232. nucliadb/writer/tus/dm.py +3 -3
  233. nucliadb/writer/tus/exceptions.py +3 -4
  234. nucliadb/writer/tus/gcs.py +15 -22
  235. nucliadb/writer/tus/s3.py +2 -3
  236. nucliadb/writer/tus/storage.py +3 -3
  237. {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/METADATA +10 -11
  238. nucliadb-6.10.0.post5705.dist-info/RECORD +410 -0
  239. nucliadb/common/datamanagers/entities.py +0 -139
  240. nucliadb/common/external_index_providers/pinecone.py +0 -894
  241. nucliadb/ingest/orm/processor/pgcatalog.py +0 -129
  242. nucliadb/search/search/hydrator.py +0 -197
  243. nucliadb-6.7.2.post4874.dist-info/RECORD +0 -383
  244. {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/WHEEL +0 -0
  245. {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/entry_points.txt +0 -0
  246. {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/top_level.txt +0 -0
@@ -17,34 +17,36 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
-import asyncio
 import copy
-from collections import deque
+from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Deque, Dict, List, Optional, Sequence, Tuple, Union, cast
+from typing import cast
 
 import yaml
 from pydantic import BaseModel
 
-from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB, FieldId, ParagraphId
-from nucliadb.common.maindb.utils import get_driver
-from nucliadb.common.models_utils import from_proto
-from nucliadb.ingest.fields.base import Field
-from nucliadb.ingest.fields.conversation import Conversation
-from nucliadb.ingest.fields.file import File
-from nucliadb.ingest.orm.knowledgebox import KnowledgeBox as KnowledgeBoxORM
-from nucliadb.search import logger
-from nucliadb.search.search import cache
-from nucliadb.search.search.chat.images import (
-    get_file_thumbnail_image,
-    get_page_image,
-    get_paragraph_image,
+import nucliadb_models
+from nucliadb.common.ids import (
+    FIELD_TYPE_STR_TO_NAME,
+    FieldId,
+    ParagraphId,
 )
-from nucliadb.search.search.hydrator import hydrate_field_text, hydrate_resource_text
+from nucliadb.search import logger
+from nucliadb.search.search.chat import rpc
 from nucliadb.search.search.metrics import Metrics
-from nucliadb.search.search.paragraphs import get_paragraph_text
+from nucliadb_models.augment import (
+    AugmentedConversationField,
+    AugmentedField,
+    AugmentedFileField,
+    AugmentFields,
+    AugmentParagraph,
+    AugmentParagraphs,
+    AugmentRequest,
+    AugmentResourceFields,
+    AugmentResources,
+)
+from nucliadb_models.common import FieldTypeName
 from nucliadb_models.labels import translate_alias_to_system_label
-from nucliadb_models.metadata import Extra, Origin
 from nucliadb_models.search import (
     SCORE_TYPE,
     AugmentedContext,
@@ -71,24 +73,9 @@ from nucliadb_models.search import (
     TextBlockAugmentationType,
     TextPosition,
 )
-from nucliadb_protos import resources_pb2
-from nucliadb_protos.resources_pb2 import ExtractedText, FieldComputedMetadata
-from nucliadb_utils.asyncio_utils import run_concurrently
-from nucliadb_utils.utilities import get_storage
-
-MAX_RESOURCE_TASKS = 5
-MAX_RESOURCE_FIELD_TASKS = 4
-
-
-# Number of messages to pull after a match in a message
-# The hope here is it will be enough to get the answer to the question.
-CONVERSATION_MESSAGE_CONTEXT_EXPANSION = 15
-
-TextBlockId = Union[ParagraphId, FieldId]
+from nucliadb_protos.resources_pb2 import FieldComputedMetadata
 
-
-class ParagraphIdNotFoundInExtractedMetadata(Exception):
-    pass
+TextBlockId = ParagraphId | FieldId
 
 
 class CappedPromptContext:
@@ -97,7 +84,7 @@ class CappedPromptContext:
     and automatically trim data that exceeds the limit when it's being set on the dictionary.
     """
 
-    def __init__(self, max_size: Optional[int]):
+    def __init__(self, max_size: int | None):
        self.output: PromptContext = {}
        self.images: PromptContextImages = {}
        self.max_size = max_size
@@ -158,79 +145,6 @@
         return self.output
 
 
-async def get_next_conversation_messages(
-    *,
-    field_obj: Conversation,
-    page: int,
-    start_idx: int,
-    num_messages: int,
-    message_type: Optional[resources_pb2.Message.MessageType.ValueType] = None,
-    msg_to: Optional[str] = None,
-) -> List[resources_pb2.Message]:
-    output = []
-    cmetadata = await field_obj.get_metadata()
-    for current_page in range(page, cmetadata.pages + 1):
-        conv = await field_obj.db_get_value(current_page)
-        for message in conv.messages[start_idx:]:
-            if message_type is not None and message.type != message_type:  # pragma: no cover
-                continue
-            if msg_to is not None and msg_to not in message.to:  # pragma: no cover
-                continue
-            output.append(message)
-            if len(output) >= num_messages:
-                return output
-        start_idx = 0
-
-    return output
-
-
-async def find_conversation_message(
-    field_obj: Conversation, mident: str
-) -> tuple[Optional[resources_pb2.Message], int, int]:
-    cmetadata = await field_obj.get_metadata()
-    for page in range(1, cmetadata.pages + 1):
-        conv = await field_obj.db_get_value(page)
-        for idx, message in enumerate(conv.messages):
-            if message.ident == mident:
-                return message, page, idx
-    return None, -1, -1
-
-
-async def get_expanded_conversation_messages(
-    *,
-    kb: KnowledgeBoxORM,
-    rid: str,
-    field_id: str,
-    mident: str,
-    max_messages: int = CONVERSATION_MESSAGE_CONTEXT_EXPANSION,
-) -> list[resources_pb2.Message]:
-    resource = await kb.get(rid)
-    if resource is None:  # pragma: no cover
-        return []
-    field_obj: Conversation = await resource.get_field(field_id, FIELD_TYPE_STR_TO_PB["c"], load=True)  # type: ignore
-    found_message, found_page, found_idx = await find_conversation_message(
-        field_obj=field_obj, mident=mident
-    )
-    if found_message is None:  # pragma: no cover
-        return []
-    elif found_message.type == resources_pb2.Message.MessageType.QUESTION:
-        # only try to get answer if it was a question
-        return await get_next_conversation_messages(
-            field_obj=field_obj,
-            page=found_page,
-            start_idx=found_idx + 1,
-            num_messages=1,
-            message_type=resources_pb2.Message.MessageType.ANSWER,
-        )
-    else:
-        return await get_next_conversation_messages(
-            field_obj=field_obj,
-            page=found_page,
-            start_idx=found_idx + 1,
-            num_messages=max_messages,
-        )
-
-
 async def default_prompt_context(
     context: CappedPromptContext,
     kbid: str,
@@ -245,35 +159,59 @@ async def default_prompt_context(
     - User context is inserted first, in order of appearance.
     - Using an dict prevents from duplicates pulled in through conversation expansion.
     """
-    # Sort retrieved paragraphs by decreasing order (most relevant first)
-    async with get_driver().ro_transaction() as txn:
-        storage = await get_storage()
-        kb = KnowledgeBoxORM(txn, storage, kbid)
-        for paragraph in ordered_paragraphs:
-            context[paragraph.id] = _clean_paragraph_text(paragraph)
 
-            # If the paragraph is a conversation and it matches semantically, we assume we
-            # have matched with the question, therefore try to include the answer to the
-            # context by pulling the next few messages of the conversation field
-            rid, field_type, field_id, mident = paragraph.id.split("/")[:4]
-            if field_type == "c" and paragraph.score_type in (
-                SCORE_TYPE.VECTOR,
-                SCORE_TYPE.BOTH,
-            ):
-                expanded_msgs = await get_expanded_conversation_messages(
-                    kb=kb, rid=rid, field_id=field_id, mident=mident
-                )
-                for msg in expanded_msgs:
-                    text = msg.content.text.strip()
-                    pid = f"{rid}/{field_type}/{field_id}/{msg.ident}/0-{len(msg.content.text) + 1}"
-                    context[pid] = text
+    conversations = []
+
+    for paragraph in ordered_paragraphs:
+        context[paragraph.id] = _clean_paragraph_text(paragraph)
+
+        # If the paragraph is a conversation and it matches semantically, we
+        # assume we have matched with the question, therefore try to include the
+        # answer to the context by pulling the next few messages of the
+        # conversation field
+        rid, field_type, field_id, mident = paragraph.id.split("/")[:4]
+        # FIXME: a semantic paragraph can have reranker score. Once we
+        # refactor and have access to the score history, we can fix this
+        if field_type == "c" and paragraph.score_type in (
+            SCORE_TYPE.VECTOR,
+            SCORE_TYPE.BOTH,
+        ):
+            conversations.append(f"{rid}/{field_type}/{field_id}/{mident}")
+
+    augment = AugmentRequest(
+        fields=[
+            AugmentFields(
+                given=[id for id in conversations],
+                conversation_answer_or_messages_after=True,
+            ),
+        ]
+    )
+    augmented = await rpc.augment(kbid, augment)
+
+    for id in conversations:
+        conversation_id = FieldId.from_string(id)
+
+        augmented_field = augmented.fields.get(conversation_id.full_without_subfield())
+        if augmented_field is None or not isinstance(augmented_field, AugmentedConversationField):
+            continue
+
+        for message in augmented_field.messages or []:
+            if message.text is None:
+                continue
+
+            message_id = copy.copy(conversation_id)
+            message_id.subfield_id = message.ident
+            pid = ParagraphId(
+                field_id=message_id, paragraph_start=0, paragraph_end=len(message.text)
+            ).full()
+            context[pid] = message.text
 
 
 async def full_resource_prompt_context(
     context: CappedPromptContext,
     kbid: str,
     ordered_paragraphs: list[FindParagraph],
-    resource: Optional[str],
+    rid: str | None,
     strategy: FullResourceStrategy,
     metrics: Metrics,
     augmented_context: AugmentedContext,
@@ -288,16 +226,16 @@ async def full_resource_prompt_context(
         ordered_paragraphs: The results of the retrieval (find) operation.
         resource: The resource to be included in the context. This is used only when chatting with a specific resource with no retrieval.
         strategy: strategy instance containing, for example, the number of full resources to include in the context.
-    """  # noqa: E501
-    if resource is not None:
+    """
+    if rid is not None:
         # The user has specified a resource to be included in the context.
-        ordered_resources = [resource]
+        ordered_resources = [rid]
     else:
         # Collect the list of resources in the results (in order of relevance).
         ordered_resources = []
         for paragraph in ordered_paragraphs:
-            resource_uuid = parse_text_block_id(paragraph.id).rid
-            if resource_uuid not in ordered_resources:
+            rid = parse_text_block_id(paragraph.id).rid
+            if rid not in ordered_resources:
                 skip = False
                 if strategy.apply_to is not None:
                     # decide whether the resource should be extended or not
@@ -307,35 +245,62 @@ async def full_resource_prompt_context(
                     )
 
                 if not skip:
-                    ordered_resources.append(resource_uuid)
-
-    # For each resource, collect the extracted text from all its fields.
-    resources_extracted_texts = await run_concurrently(
-        [
-            hydrate_resource_text(kbid, resource_uuid, max_concurrent_tasks=MAX_RESOURCE_FIELD_TASKS)
-            for resource_uuid in ordered_resources[: strategy.count]
-        ],
-        max_concurrent=MAX_RESOURCE_TASKS,
+                    ordered_resources.append(rid)
+                    # skip when we have enough resource ids
+                    if strategy.count is not None and len(ordered_resources) > strategy.count:
+                        break
+
+    ordered_resources = ordered_resources[: strategy.count]
+
+    # For each resource, collect the extracted text from all its fields and
+    # include the title and summary as well
+    augmented = await rpc.augment(
+        kbid,
+        AugmentRequest(
+            resources=[
+                AugmentResources(
+                    given=ordered_resources,
+                    title=True,
+                    summary=True,
+                    fields=AugmentResourceFields(
+                        text=True,
+                        filters=[],
+                    ),
+                )
+            ]
+        ),
     )
+
+    extracted_texts = {}
+    for rid, resource in augmented.resources.items():
+        if resource.title is not None:
+            field_id = FieldId(rid=rid, type="a", key="title").full()
+            extracted_texts[field_id] = resource.title
+        if resource.summary is not None:
+            field_id = FieldId(rid=rid, type="a", key="summary").full()
+            extracted_texts[field_id] = resource.summary
+
+    for field_id, field in augmented.fields.items():
+        field = cast(AugmentedField, field)
+        if field.text is not None:
+            extracted_texts[field_id] = field.text
+
     added_fields = set()
-    for resource_extracted_texts in resources_extracted_texts:
-        if resource_extracted_texts is None:
-            continue
-        for field, extracted_text in resource_extracted_texts:
-            # First off, remove the text block ids from paragraphs that belong to
-            # the same field, as otherwise the context will be duplicated.
-            for tb_id in context.text_block_ids():
-                if tb_id.startswith(field.full()):
-                    del context[tb_id]
-            # Add the extracted text of each field to the context.
-            context[field.full()] = extracted_text
-            augmented_context.fields[field.full()] = AugmentedTextBlock(
-                id=field.full(),
-                text=extracted_text,
-                augmentation_type=TextBlockAugmentationType.FULL_RESOURCE,
-            )
+    for field_id, extracted_text in extracted_texts.items():
+        # First off, remove the text block ids from paragraphs that belong to
+        # the same field, as otherwise the context will be duplicated.
+        for tb_id in context.text_block_ids():
+            if tb_id.startswith(field_id):
+                del context[tb_id]
+        # Add the extracted text of each field to the context.
+        context[field_id] = extracted_text
+        augmented_context.fields[field_id] = AugmentedTextBlock(
+            id=field_id,
+            text=extracted_text,
+            augmentation_type=TextBlockAugmentationType.FULL_RESOURCE,
+        )
 
-            added_fields.add(field.full())
+        added_fields.add(field_id)
 
     metrics.set("full_resource_ops", len(added_fields))
 
@@ -353,213 +318,167 @@ async def extend_prompt_context_with_metadata(
     metrics: Metrics,
     augmented_context: AugmentedContext,
 ) -> None:
+    rids: list[str] = []
+    field_ids: list[str] = []
     text_block_ids: list[TextBlockId] = []
     for text_block_id in context.text_block_ids():
         try:
-            text_block_ids.append(parse_text_block_id(text_block_id))
+            tb_id = parse_text_block_id(text_block_id)
         except ValueError:  # pragma: no cover
            # Some text block ids are not paragraphs nor fields, so they are skipped
            # (e.g. USER_CONTEXT_0, when the user provides extra context)
            continue
+
+        field_id = tb_id if isinstance(tb_id, FieldId) else tb_id.field_id
+
+        text_block_ids.append(tb_id)
+        field_ids.append(field_id.full())
+        rids.append(tb_id.rid)
+
     if len(text_block_ids) == 0:  # pragma: no cover
         return
 
+    resource_origin = False
+    resource_extra = False
+    classification_labels = False
+    field_entities = False
+
     ops = 0
     if MetadataExtensionType.ORIGIN in strategy.types:
         ops += 1
-        await extend_prompt_context_with_origin_metadata(
-            context, kbid, text_block_ids, augmented_context
-        )
+        resource_origin = True
 
     if MetadataExtensionType.CLASSIFICATION_LABELS in strategy.types:
         ops += 1
-        await extend_prompt_context_with_classification_labels(
-            context, kbid, text_block_ids, augmented_context
-        )
+        classification_labels = True
 
     if MetadataExtensionType.NERS in strategy.types:
         ops += 1
-        await extend_prompt_context_with_ner(context, kbid, text_block_ids, augmented_context)
+        field_entities = True
 
     if MetadataExtensionType.EXTRA_METADATA in strategy.types:
         ops += 1
-        await extend_prompt_context_with_extra_metadata(context, kbid, text_block_ids, augmented_context)
+        resource_extra = True
 
     metrics.set("metadata_extension_ops", ops * len(text_block_ids))
 
+    augment_req = AugmentRequest()
+    if resource_origin or resource_extra or classification_labels:
+        augment_req.resources = [
+            AugmentResources(
+                given=rids,
+                origin=resource_origin,
+                extra=resource_extra,
+                classification_labels=classification_labels,
+            )
+        ]
+    if classification_labels or field_entities:
+        augment_req.fields = [
+            AugmentFields(
+                given=field_ids,
+                classification_labels=classification_labels,
+                entities=field_entities,
+            )
+        ]
 
-def parse_text_block_id(text_block_id: str) -> TextBlockId:
-    try:
-        # Typically, the text block id is a paragraph id
-        return ParagraphId.from_string(text_block_id)
-    except ValueError:
-        # When we're doing `full_resource` or `hierarchy` strategies,the text block id
-        # is a field id
-        return FieldId.from_string(text_block_id)
+    if augment_req.resources is None and augment_req.fields is None:
+        # nothing to augment
+        return
 
+    augmented = await rpc.augment(kbid, augment_req)
 
-async def extend_prompt_context_with_origin_metadata(
-    context: CappedPromptContext,
-    kbid,
-    text_block_ids: list[TextBlockId],
-    augmented_context: AugmentedContext,
-):
-    async def _get_origin(kbid: str, rid: str) -> tuple[str, Optional[Origin]]:
-        origin = None
-        resource = await cache.get_resource(kbid, rid)
-        if resource is not None:
-            pb_origin = await resource.get_origin()
-            if pb_origin is not None:
-                origin = from_proto.origin(pb_origin)
-        return rid, origin
-
-    rids = {tb_id.rid for tb_id in text_block_ids}
-    origins = await run_concurrently([_get_origin(kbid, rid) for rid in rids])
-    rid_to_origin = {rid: origin for rid, origin in origins if origin is not None}
     for tb_id in text_block_ids:
-        origin = rid_to_origin.get(tb_id.rid)
-        if origin is not None and tb_id.full() in context:
-            text = context.output.pop(tb_id.full())
-            extended_text = text + f"\n\nDOCUMENT METADATA AT ORIGIN:\n{to_yaml(origin)}"
-            context[tb_id.full()] = extended_text
-            augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
-                id=tb_id.full(),
-                text=extended_text,
-                parent=tb_id.full(),
-                augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
-            )
+        field_id = tb_id if isinstance(tb_id, FieldId) else tb_id.field_id
 
+        resource = augmented.resources.get(tb_id.rid)
+        field = augmented.fields.get(field_id.full())
 
-async def extend_prompt_context_with_classification_labels(
-    context: CappedPromptContext,
-    kbid: str,
-    text_block_ids: list[TextBlockId],
-    augmented_context: AugmentedContext,
-):
-    async def _get_labels(kbid: str, _id: TextBlockId) -> tuple[TextBlockId, list[tuple[str, str]]]:
-        fid = _id if isinstance(_id, FieldId) else _id.field_id
-        labels = set()
-        resource = await cache.get_resource(kbid, fid.rid)
         if resource is not None:
-            pb_basic = await resource.get_basic()
-            if pb_basic is not None:
-                # Add the classification labels of the resource
-                for classif in pb_basic.usermetadata.classifications:
-                    labels.add((classif.labelset, classif.label))
-                # Add the classifications labels of the field
-                for fc in pb_basic.computedmetadata.field_classifications:
-                    if fc.field.field == fid.key and fc.field.field_type == fid.pb_type:
-                        for classif in fc.classifications:
-                            if classif.cancelled_by_user:  # pragma: no cover
-                                continue
-                            labels.add((classif.labelset, classif.label))
-        return _id, list(labels)
-
-    classif_labels = await run_concurrently([_get_labels(kbid, tb_id) for tb_id in text_block_ids])
-    tb_id_to_labels = {tb_id: labels for tb_id, labels in classif_labels if len(labels) > 0}
-    for tb_id in text_block_ids:
-        labels = tb_id_to_labels.get(tb_id)
-        if labels is not None and tb_id.full() in context:
-            text = context.output.pop(tb_id.full())
-
-            labels_text = "DOCUMENT CLASSIFICATION LABELS:"
-            for labelset, label in labels:
-                labels_text += f"\n - {label} ({labelset})"
-            extended_text = text + "\n\n" + labels_text
-
-            context[tb_id.full()] = extended_text
-            augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
-                id=tb_id.full(),
-                text=extended_text,
-                parent=tb_id.full(),
-                augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
-            )
+            if resource.origin is not None:
+                text = context.output.pop(tb_id.full())
+                extended_text = text + f"\n\nDOCUMENT METADATA AT ORIGIN:\n{to_yaml(resource.origin)}"
+                context[tb_id.full()] = extended_text
+                augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
+                    id=tb_id.full(),
+                    text=extended_text,
+                    parent=tb_id.full(),
+                    augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
+                )
 
+            if resource.extra is not None:
+                text = context.output.pop(tb_id.full())
+                extended_text = text + f"\n\nDOCUMENT EXTRA METADATA:\n{to_yaml(resource.extra)}"
+                context[tb_id.full()] = extended_text
+                augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
+                    id=tb_id.full(),
+                    text=extended_text,
+                    parent=tb_id.full(),
+                    augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
+                )
 
-async def extend_prompt_context_with_ner(
-    context: CappedPromptContext,
-    kbid: str,
-    text_block_ids: list[TextBlockId],
-    augmented_context: AugmentedContext,
-):
-    async def _get_ners(kbid: str, _id: TextBlockId) -> tuple[TextBlockId, dict[str, set[str]]]:
-        fid = _id if isinstance(_id, FieldId) else _id.field_id
-        ners: dict[str, set[str]] = {}
-        resource = await cache.get_resource(kbid, fid.rid)
-        if resource is not None:
-            field = await resource.get_field(fid.key, fid.pb_type, load=False)
-            fcm = await field.get_field_metadata()
-            if fcm is not None:
-                # Data Augmentation + Processor entities
-                for (
-                    data_aumgentation_task_id,
-                    entities_wrapper,
-                ) in fcm.metadata.entities.items():
-                    for entity in entities_wrapper.entities:
-                        ners.setdefault(entity.label, set()).add(entity.text)
-                # Legacy processor entities
-                # TODO: Remove once processor doesn't use this anymore and remove the positions and ner fields from the message
-                for token, family in fcm.metadata.ner.items():
-                    ners.setdefault(family, set()).add(token)
-        return _id, ners
-
-    nerss = await run_concurrently([_get_ners(kbid, tb_id) for tb_id in text_block_ids])
-    tb_id_to_ners = {tb_id: ners for tb_id, ners in nerss if len(ners) > 0}
-    for tb_id in text_block_ids:
-        ners = tb_id_to_ners.get(tb_id)
-        if ners is not None and tb_id.full() in context:
-            text = context.output.pop(tb_id.full())
-
-            ners_text = "DOCUMENT NAMED ENTITIES (NERs):"
-            for family, tokens in ners.items():
-                ners_text += f"\n - {family}:"
-                for token in sorted(list(tokens)):
-                    ners_text += f"\n - {token}"
-
-            extended_text = text + "\n\n" + ners_text
-
-            context[tb_id.full()] = extended_text
-            augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
-                id=tb_id.full(),
-                text=extended_text,
-                parent=tb_id.full(),
-                augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
-            )
+        if tb_id.full() in context:
+            if (resource is not None and resource.classification_labels) or (
+                field is not None and field.classification_labels
+            ):
+                text = context.output.pop(tb_id.full())
+
+                labels_text = "DOCUMENT CLASSIFICATION LABELS:"
+                if resource is not None and resource.classification_labels:
+                    for labelset, labels in resource.classification_labels.items():
+                        for label in labels:
+                            labels_text += f"\n - {label} ({labelset})"
+
+                if field is not None and field.classification_labels:
+                    for labelset, labels in field.classification_labels.items():
+                        for label in labels:
+                            labels_text += f"\n - {label} ({labelset})"
+
+                extended_text = text + "\n\n" + labels_text
+
+                context[tb_id.full()] = extended_text
+                augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
+                    id=tb_id.full(),
+                    text=extended_text,
+                    parent=tb_id.full(),
+                    augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
+                )
 
+            if field is not None and field.entities:
+                ners = field.entities
 
-async def extend_prompt_context_with_extra_metadata(
-    context: CappedPromptContext,
-    kbid: str,
-    text_block_ids: list[TextBlockId],
-    augmented_context: AugmentedContext,
-):
-    async def _get_extra(kbid: str, rid: str) -> tuple[str, Optional[Extra]]:
-        extra = None
-        resource = await cache.get_resource(kbid, rid)
-        if resource is not None:
-            pb_extra = await resource.get_extra()
-            if pb_extra is not None:
-                extra = from_proto.extra(pb_extra)
-        return rid, extra
-
-    rids = {tb_id.rid for tb_id in text_block_ids}
-    extras = await run_concurrently([_get_extra(kbid, rid) for rid in rids])
-    rid_to_extra = {rid: extra for rid, extra in extras if extra is not None}
-    for tb_id in text_block_ids:
-        extra = rid_to_extra.get(tb_id.rid)
-        if extra is not None and tb_id.full() in context:
-            text = context.output.pop(tb_id.full())
-            extended_text = text + f"\n\nDOCUMENT EXTRA METADATA:\n{to_yaml(extra)}"
-            context[tb_id.full()] = extended_text
-            augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
-                id=tb_id.full(),
-                text=extended_text,
-                parent=tb_id.full(),
-                augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
-            )
+                text = context.output.pop(tb_id.full())
+
+                ners_text = "DOCUMENT NAMED ENTITIES (NERs):"
+                for family, tokens in ners.items():
+                    ners_text += f"\n - {family}:"
+                    for token in sorted(list(tokens)):
+                        ners_text += f"\n - {token}"
+
+                extended_text = text + "\n\n" + ners_text
+
+                context[tb_id.full()] = extended_text
+                augmented_context.paragraphs[tb_id.full()] = AugmentedTextBlock(
+                    id=tb_id.full(),
+                    text=extended_text,
+                    parent=tb_id.full(),
+                    augmentation_type=TextBlockAugmentationType.METADATA_EXTENSION,
+                )
+
+
+def parse_text_block_id(text_block_id: str) -> TextBlockId:
+    try:
+        # Typically, the text block id is a paragraph id
+        return ParagraphId.from_string(text_block_id)
+    except ValueError:
+        # When we're doing `full_resource` or `hierarchy` strategies,the text block id
+        # is a field id
+        return FieldId.from_string(text_block_id)
 
 
 def to_yaml(obj: BaseModel) -> str:
+    # FIXME: this dumps enums REALLY poorly, e.g.,
+    # `!!python/object/apply:nucliadb_models.metadata.Source\n- WEB` for
+    # Source.WEB instead of `WEB`
     return yaml.dump(
         obj.model_dump(exclude_none=True, exclude_defaults=True, exclude_unset=True),
         default_flow_style=False,
@@ -589,37 +508,74 @@ async def field_extension_prompt_context(
         if resource_uuid not in ordered_resources:
             ordered_resources.append(resource_uuid)
 
-    # Fetch the extracted texts of the specified fields for each resource
-    extend_fields = strategy.fields
-    extend_field_ids = []
-    for resource_uuid in ordered_resources:
-        for field_id in extend_fields:
-            try:
-                fid = FieldId.from_string(f"{resource_uuid}/{field_id.strip('/')}")
-                extend_field_ids.append(fid)
-            except ValueError:  # pragma: no cover
-                # Invalid field id, skiping
-                continue
+    resource_title = False
+    resource_summary = False
+    filters: list[nucliadb_models.filters.Field | nucliadb_models.filters.Generated] = []
+    # this strategy exposes a way to access resource title and summary using a
+    # field id. However, as they are resource properties, we must request it as
+    # that
+    for name in strategy.fields:
+        if name == "a/title":
+            resource_title = True
+        elif name == "a/summary":
+            resource_summary = True
+        else:
+            # model already enforces type/name format
+            field_type, field_name = name.split("/")
+            filters.append(
+                nucliadb_models.filters.Field(
+                    type=FIELD_TYPE_STR_TO_NAME[field_type], name=field_name or None
+                )
+            )
 
-    tasks = [hydrate_field_text(kbid, fid) for fid in extend_field_ids]
-    field_extracted_texts = await run_concurrently(tasks)
+    for da_prefix in strategy.data_augmentation_field_prefixes:
+        filters.append(nucliadb_models.filters.Generated(by="data-augmentation", da_task=da_prefix))
+
+    augmented = await rpc.augment(
+        kbid,
+        AugmentRequest(
+            resources=[
+                AugmentResources(
+                    given=ordered_resources,
+                    title=resource_title,
+                    summary=resource_summary,
+                    fields=AugmentResourceFields(
+                        text=True,
+                        filters=filters,
+                    ),
+                )
+            ]
+        ),
+    )
+
+    # REVIEW(decoupled-ask): we don't have the field count anymore, is this good enough?
+    metrics.set("field_extension_ops", len(ordered_resources))
 
-    metrics.set("field_extension_ops", len(field_extracted_texts))
+    extracted_texts = {}
+    # now we need to expose title and summary as fields again, so it gets
+    # consistent with the view we are providing in the API
+    for rid, augmented_resource in augmented.resources.items():
+        if augmented_resource.title:
+            extracted_texts[f"{rid}/a/title"] = augmented_resource.title
+        if augmented_resource.summary:
+            extracted_texts[f"{rid}/a/summary"] = augmented_resource.summary
 
-    for result in field_extracted_texts:
-        if result is None:  # pragma: no cover
+    for fid, augmented_field in augmented.fields.items():
+        if augmented_field is None or augmented_field.text is None:  # pragma: no cover
             continue
-        field, extracted_text = result
+        extracted_texts[fid] = augmented_field.text
+
+    for fid, extracted_text in extracted_texts.items():
         # First off, remove the text block ids from paragraphs that belong to
         # the same field, as otherwise the context will be duplicated.
         for tb_id in context.text_block_ids():
-            if tb_id.startswith(field.full()):
+            if tb_id.startswith(fid):
                 del context[tb_id]
         # Add the extracted text of each field to the beginning of the context.
-        if field.full() not in context:
-            context[field.full()] = extracted_text
-            augmented_context.fields[field.full()] = AugmentedTextBlock(
-                id=field.full(),
+        if fid not in context:
+            context[fid] = extracted_text
+            augmented_context.fields[fid] = AugmentedTextBlock(
+                id=fid,
                 text=extracted_text,
                 augmentation_type=TextBlockAugmentationType.FIELD_EXTENSION,
             )
@@ -630,13 +586,6 @@ async def field_extension_prompt_context(
             context[paragraph.id] = _clean_paragraph_text(paragraph)
 
 
-async def get_orm_field(kbid: str, field_id: FieldId) -> Optional[Field]:
-    resource = await cache.get_resource(kbid, field_id.rid)
-    if resource is None:  # pragma: no cover
-        return None
-    return await resource.get_field(key=field_id.key, type=field_id.pb_type, load=False)
-
-
 async def neighbouring_paragraphs_prompt_context(
     context: CappedPromptContext,
     kbid: str,
@@ -652,83 +601,52 @@ async def neighbouring_paragraphs_prompt_context(
     retrieved_paragraphs_ids = [
         ParagraphId.from_string(text_block.id) for text_block in ordered_text_blocks
     ]
-    unique_field_ids = list({pid.field_id for pid in retrieved_paragraphs_ids})
-
-    # Get extracted texts and metadatas for all fields
-    fm_ops = []
-    et_ops = []
-    for field_id in unique_field_ids:
-        field = await get_orm_field(kbid, field_id)
-        if field is None:
-            continue
-        fm_ops.append(asyncio.create_task(field.get_field_metadata()))
-        et_ops.append(asyncio.create_task(field.get_extracted_text()))
-
-    field_metadatas: dict[FieldId, FieldComputedMetadata] = {
-        fid: fm for fid, fm in zip(unique_field_ids, await asyncio.gather(*fm_ops)) if fm is not None
-    }
-    extracted_texts: dict[FieldId, ExtractedText] = {
-        fid: et for fid, et in zip(unique_field_ids, await asyncio.gather(*et_ops)) if et is not None
-    }
-
-    def _get_paragraph_text(extracted_text: ExtractedText, pid: ParagraphId) -> str:
-        if pid.field_id.subfield_id:
-            text = extracted_text.split_text.get(pid.field_id.subfield_id) or ""
-        else:
-            text = extracted_text.text
-        return text[pid.paragraph_start : pid.paragraph_end]
+
+    augmented = await rpc.augment(
+        kbid,
+        AugmentRequest(
+            paragraphs=[
+                AugmentParagraphs(
+                    given=[AugmentParagraph(id=pid.full()) for pid in retrieved_paragraphs_ids],
+                    text=True,
+                    neighbours_before=strategy.before,
+                    neighbours_after=strategy.after,
+                )
+            ]
+        ),
+    )
 
     for pid in retrieved_paragraphs_ids:
-        # Add the retrieved paragraph first
-        field_extracted_text = extracted_texts.get(pid.field_id, None)
-        if field_extracted_text is None:
+        paragraph = augmented.paragraphs.get(pid.full())
+        if paragraph is None:
             continue
-        ptext = _get_paragraph_text(field_extracted_text, pid)
+
+        ptext = paragraph.text or ""
         if ptext and pid.full() not in context:
             context[pid.full()] = ptext
 
         # Now add the neighbouring paragraphs
-        field_extracted_metadata = field_metadatas.get(pid.field_id, None)
-        if field_extracted_metadata is None:
-            continue
-
-        field_pids = [
-            ParagraphId(
-                field_id=pid.field_id,
-                paragraph_start=p.start,
-                paragraph_end=p.end,
-            )
-            for p in field_extracted_metadata.metadata.paragraphs
+        neighbour_ids = [
+            *(paragraph.neighbours_before or []),
+            *(paragraph.neighbours_after or []),
         ]
-        try:
-            index = field_pids.index(pid)
-        except ValueError:
-            continue
+        for npid in neighbour_ids:
+            neighbour = augmented.paragraphs.get(npid)
+            assert neighbour is not None, "augment should never return dangling paragraph references"
 
-        for neighbour_index in get_neighbouring_indices(
-            index=index,
-            before=strategy.before,
-            after=strategy.after,
-            field_pids=field_pids,
-        ):
-            if neighbour_index == index:
-                # Already handled above
+            if ParagraphId.from_string(npid) in retrieved_paragraphs_ids or npid in context:
+                # already added
                 continue
-            try:
-                npid = field_pids[neighbour_index]
-            except IndexError:
-                continue
-            if npid in retrieved_paragraphs_ids or npid.full() in context:
-                # Already added
-                continue
-            ptext = _get_paragraph_text(field_extracted_text, npid)
-            if not ptext:
+
+            ntext = neighbour.text
+            if not ntext:
                 continue
-            context[npid.full()] = ptext
-            augmented_context.paragraphs[npid.full()] = AugmentedTextBlock(
-                id=npid.full(),
-                text=ptext,
-                position=get_text_position(npid, neighbour_index, field_extracted_metadata),
+
+            context[npid] = ntext
+            augmented_context.paragraphs[npid] = AugmentedTextBlock(
+                id=npid,
+                text=ntext,
+                position=neighbour.position,
                 parent=pid.full(),
                 augmentation_type=TextBlockAugmentationType.NEIGHBOURING_PARAGRAPHS,
             )
@@ -738,7 +656,7 @@
 
 def get_text_position(
     paragraph_id: ParagraphId, index: int, field_metadata: FieldComputedMetadata
-) -> Optional[TextPosition]:
+) -> TextPosition | None:
     if paragraph_id.field_id.subfield_id:
         metadata = field_metadata.split_metadata[paragraph_id.field_id.subfield_id]
     else:
@@ -777,148 +695,144 @@ async def conversation_prompt_context(
777
695
  metrics: Metrics,
778
696
  augmented_context: AugmentedContext,
779
697
  ):
780
- analyzed_fields: List[str] = []
698
+ analyzed_fields: list[str] = []
781
699
  ops = 0
782
- async with get_driver().ro_transaction() as txn:
783
- storage = await get_storage()
784
- kb = KnowledgeBoxORM(txn, storage, kbid)
785
- for paragraph in ordered_paragraphs:
786
- if paragraph.id not in context:
787
- context[paragraph.id] = _clean_paragraph_text(paragraph)
788
700
 
789
- # If the paragraph is a conversation and it matches semantically, we assume we
790
- # have matched with the question, therefore try to include the answer to the
791
- # context by pulling the next few messages of the conversation field
792
- rid, field_type, field_id, mident = paragraph.id.split("/")[:4]
793
- if field_type == "c" and paragraph.score_type in (
794
- SCORE_TYPE.VECTOR,
795
- SCORE_TYPE.BOTH,
796
- SCORE_TYPE.BM25,
797
- ):
798
- field_unique_id = "-".join([rid, field_type, field_id])
799
- if field_unique_id in analyzed_fields:
701
+ conversation_paragraphs = []
702
+ for paragraph in ordered_paragraphs:
703
+ if paragraph.id not in context:
704
+ context[paragraph.id] = _clean_paragraph_text(paragraph)
705
+
706
+ parent_paragraph_id = ParagraphId.from_string(paragraph.id)
707
+
708
+ if parent_paragraph_id.field_id.type != FieldTypeName.CONVERSATION.abbreviation():
709
+ # conversational strategy only applies to conversation fields
710
+ continue
711
+
712
+ field_unique_id = parent_paragraph_id.field_id.full_without_subfield()
713
+ if field_unique_id in analyzed_fields:
714
+ continue
715
+
716
+ conversation_paragraphs.append((parent_paragraph_id, paragraph))
717
+
718
+ # augment conversation paragraphs
719
+
720
+ if strategy.full:
721
+ full_conversation = True
722
+ max_conversation_messages = None
723
+ else:
724
+ full_conversation = False
725
+ max_conversation_messages = strategy.max_messages
726
+
727
+ augment = AugmentRequest(
728
+ fields=[
729
+ AugmentFields(
730
+ given=[paragraph_id.field_id.full() for paragraph_id, _ in conversation_paragraphs],
731
+ full_conversation=full_conversation,
732
+ max_conversation_messages=max_conversation_messages,
733
+ conversation_text_attachments=strategy.attachments_text,
734
+ conversation_image_attachments=strategy.attachments_images,
735
+ )
736
+ ]
737
+ )
738
+ augmented = await rpc.augment(kbid, augment)
739
+
740
+ attachments: dict[ParagraphId, list[FieldId]] = {}
741
+ for parent_paragraph_id, paragraph in conversation_paragraphs:
742
+ fid = parent_paragraph_id.field_id
743
+ field = augmented.fields.get(fid.full_without_subfield())
744
+ if field is not None:
745
+ field = cast(AugmentedConversationField, field)
746
+ for _message in field.messages or []:
747
+ ops += 1
748
+ if not _message.text:
749
+ continue
750
+
751
+ text = _message.text
752
+ pid = ParagraphId(
753
+ field_id=FieldId(
754
+ rid=fid.rid,
755
+ type=fid.type,
756
+ key=fid.key,
757
+ subfield_id=_message.ident,
758
+ ),
759
+ paragraph_start=0,
760
+ paragraph_end=len(text),
761
+ ).full()
762
+ if pid in context:
800
763
  continue
801
- resource = await kb.get(rid)
802
- if resource is None: # pragma: no cover
764
+ context[pid] = text
765
+
766
+ attachments.setdefault(parent_paragraph_id, []).extend(
767
+ [FieldId.from_string(attachment_id) for attachment_id in field.attachments or []]
768
+ )
769
+ augmented_context.paragraphs[pid] = AugmentedTextBlock(
770
+ id=pid,
771
+ text=text,
772
+ parent=paragraph.id,
773
+ augmentation_type=TextBlockAugmentationType.CONVERSATION,
774
+ )
775
+
776
+ # augment attachments
777
+
778
+ if strategy.attachments_text or (
779
+ (strategy.attachments_images and visual_llm) and len(attachments) > 0
780
+ ):
781
+ augment = AugmentRequest(
782
+ fields=[
783
+ AugmentFields(
784
+ given=[
785
+ id.full()
786
+                    for paragraph_attachments in attachments.values()
+                    for id in paragraph_attachments
+                ],
+                text=strategy.attachments_text,
+                file_thumbnail=(strategy.attachments_images and visual_llm),
+            )
+        ]
+    )
+    augmented = await rpc.augment(kbid, augment)
+
+    for parent_paragraph_id, paragraph_attachments in attachments.items():
+        for attachment_id in paragraph_attachments:
+            attachment_field = augmented.fields.get(attachment_id.full())
+
+            if attachment_field is None:
                 continue
 
-            field_obj: Conversation = await resource.get_field(
-                field_id, FIELD_TYPE_STR_TO_PB["c"], load=True
-            )  # type: ignore
-            cmetadata = await field_obj.get_metadata()
-
-            attachments: List[resources_pb2.FieldRef] = []
-            if strategy.full:
-                ops += 5
-                extracted_text = await field_obj.get_extracted_text()
-                for current_page in range(1, cmetadata.pages + 1):
-                    conv = await field_obj.db_get_value(current_page)
-
-                    for message in conv.messages:
-                        ident = message.ident
-                        if extracted_text is not None:
-                            text = extracted_text.split_text.get(ident, message.content.text.strip())
-                        else:
-                            text = message.content.text.strip()
-                        pid = f"{rid}/{field_type}/{field_id}/{ident}/0-{len(text) + 1}"
-                        attachments.extend(message.content.attachments_fields)
-                        if pid in context:
-                            continue
-                        context[pid] = text
-                        augmented_context.paragraphs[pid] = AugmentedTextBlock(
-                            id=pid,
-                            text=text,
-                            parent=paragraph.id,
-                            augmentation_type=TextBlockAugmentationType.CONVERSATION,
-                        )
-            else:
-                # Add first message
-                extracted_text = await field_obj.get_extracted_text()
-                first_page = await field_obj.db_get_value()
-                if len(first_page.messages) > 0:
-                    message = first_page.messages[0]
-                    ident = message.ident
-                    if extracted_text is not None:
-                        text = extracted_text.split_text.get(ident, message.content.text.strip())
-                    else:
-                        text = message.content.text.strip()
-                    attachments.extend(message.content.attachments_fields)
-                    pid = f"{rid}/{field_type}/{field_id}/{ident}/0-{len(text) + 1}"
-                    if pid in context:
-                        continue
-                    context[pid] = text
-                    augmented_context.paragraphs[pid] = AugmentedTextBlock(
-                        id=pid,
-                        text=text,
-                        parent=paragraph.id,
-                        augmentation_type=TextBlockAugmentationType.CONVERSATION,
-                    )
+            if strategy.attachments_text and attachment_field.text:
+                ops += 1
 
-                messages: Deque[resources_pb2.Message] = deque(maxlen=strategy.max_messages)
-
-                pending = -1
-                for page in range(1, cmetadata.pages + 1):
-                    # Collect the messages within the window asked by the user around the matched paragraph
-                    conv = await field_obj.db_get_value(page)
-                    for message in conv.messages:
-                        messages.append(message)
-                        if pending > 0:
-                            pending -= 1
-                        if message.ident == mident:
-                            pending = (strategy.max_messages - 1) // 2
-                        if pending == 0:
-                            break
-                    if pending == 0:
-                        break
-
-                for message in messages:
-                    ops += 1
-                    text = message.content.text.strip()
-                    attachments.extend(message.content.attachments_fields)
-                    pid = f"{rid}/{field_type}/{field_id}/{message.ident}/0-{len(message.content.text) + 1}"
-                    if pid in context:
-                        continue
+                pid = f"{attachment_id.full_without_subfield()}/0-{len(attachment_field.text)}"
+                if pid not in context:
+                    text = f"Attachment {attachment_id.key}: {attachment_field.text}\n\n"
                     context[pid] = text
                     augmented_context.paragraphs[pid] = AugmentedTextBlock(
                         id=pid,
                         text=text,
-                        parent=paragraph.id,
+                        parent=parent_paragraph_id.full(),
                         augmentation_type=TextBlockAugmentationType.CONVERSATION,
                     )
 
-            if strategy.attachments_text:
-                # add the attachments' extracted text to the context
-                for attachment in attachments:
-                    ops += 1
-                    field: File = await resource.get_field(
-                        attachment.field_id, attachment.field_type, load=True
-                    )  # type: ignore
-                    extracted_text = await field.get_extracted_text()
-                    if extracted_text is not None:
-                        pid = f"{rid}/{field_type}/{attachment.field_id}/0-{len(extracted_text.text) + 1}"
-                        if pid in context:
-                            continue
-                        text = f"Attachment {attachment.field_id}: {extracted_text.text}\n\n"
-                        context[pid] = text
-                        augmented_context.paragraphs[pid] = AugmentedTextBlock(
-                            id=pid,
-                            text=text,
-                            parent=paragraph.id,
-                            augmentation_type=TextBlockAugmentationType.CONVERSATION,
-                        )
-
-            if strategy.attachments_images and visual_llm:
-                for attachment in attachments:
-                    ops += 1
-                    file_field: File = await resource.get_field(
-                        attachment.field_id, attachment.field_type, load=True
-                    )  # type: ignore
-                    image = await get_file_thumbnail_image(file_field)
-                    if image is not None:
-                        pid = f"{rid}/f/{attachment.field_id}/0-0"
-                        context.images[pid] = image
-
-            analyzed_fields.append(field_unique_id)
+            if (
+                (strategy.attachments_images and visual_llm)
+                and isinstance(attachment_field, AugmentedFileField)
+                and attachment_field.thumbnail_image
+            ):
+                ops += 1
+
+                image = await rpc.download_image(
+                    kbid,
+                    attachment_id,
+                    attachment_field.thumbnail_image,
+                    # We assume the thumbnail is always generated as JPEG by Nuclia processing
+                    mime_type="image/jpeg",
+                )
+                if image is not None:
+                    pid = f"{attachment_id.rid}/f/{attachment_id.key}/0-0"
+                    context.images[pid] = image
+
+    analyzed_fields.append(field_unique_id)
     metrics.set("conversation_ops", ops)
 
 
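For readers following the removed branch above: it centered a bounded window of messages around the matched conversation message, with the `pending` counter reserving roughly half of the window for messages that follow the match. A standalone sketch of that windowing logic (the helper name and plain-string messages are illustrative, not nucliadb API):

    from collections import deque

    def select_message_window(idents: list[str], match_ident: str, max_messages: int) -> list[str]:
        # Bounded window: once full, older messages fall off the left side.
        window: deque[str] = deque(maxlen=max_messages)
        pending = -1  # -1 until the matched message has been seen
        for ident in idents:
            window.append(ident)
            if pending > 0:
                pending -= 1
            if ident == match_ident:
                # Reserve roughly half the window for messages after the match.
                pending = (max_messages - 1) // 2
            if pending == 0:
                break
        return list(window)

    # select_message_window(["m1", "m2", "m3", "m4", "m5"], "m3", 3) -> ["m2", "m3", "m4"]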
@@ -939,66 +853,93 @@ async def hierarchy_prompt_context(
     # Make a copy of the ordered paragraphs to avoid modifying the original list, which is returned
     # in the response to the user
     ordered_paragraphs_copy = copy.deepcopy(ordered_paragraphs)
-    resources: Dict[str, ExtraCharsParagraph] = {}
+    resources: dict[str, ExtraCharsParagraph] = {}
 
     # Iterate paragraphs to get extended text
+    paragraphs_to_augment = []
     for paragraph in ordered_paragraphs_copy:
         paragraph_id = ParagraphId.from_string(paragraph.id)
-        extended_paragraph_text = paragraph.text
-        if paragraphs_extra_characters > 0:
-            extended_paragraph_text = await get_paragraph_text(
-                kbid=kbid,
-                paragraph_id=paragraph_id,
-                log_on_missing_field=True,
-            )
         rid = paragraph_id.rid
+
+        if paragraphs_extra_characters > 0:
+            paragraph_id.paragraph_end += paragraphs_extra_characters
+
+        paragraphs_to_augment.append(paragraph_id)
+
         if rid not in resources:
             # Get the title and the summary of the resource
-            title_text = await get_paragraph_text(
-                kbid=kbid,
-                paragraph_id=ParagraphId(
-                    field_id=FieldId(
-                        rid=rid,
-                        type="a",
-                        key="title",
-                    ),
-                    paragraph_start=0,
-                    paragraph_end=500,
+            title_paragraph_id = ParagraphId(
+                field_id=FieldId(
+                    rid=rid,
+                    type="a",
+                    key="title",
                 ),
-                log_on_missing_field=False,
+                paragraph_start=0,
+                paragraph_end=500,
             )
-            summary_text = await get_paragraph_text(
-                kbid=kbid,
-                paragraph_id=ParagraphId(
-                    field_id=FieldId(
-                        rid=rid,
-                        type="a",
-                        key="summary",
-                    ),
-                    paragraph_start=0,
-                    paragraph_end=1000,
+            summary_paragraph_id = ParagraphId(
+                field_id=FieldId(
+                    rid=rid,
+                    type="a",
+                    key="summary",
                 ),
-                log_on_missing_field=False,
+                paragraph_start=0,
+                paragraph_end=1000,
             )
+            paragraphs_to_augment.append(title_paragraph_id)
+            paragraphs_to_augment.append(summary_paragraph_id)
+
             resources[rid] = ExtraCharsParagraph(
-                title=title_text,
-                summary=summary_text,
-                paragraphs=[(paragraph, extended_paragraph_text)],
+                title=title_paragraph_id,
+                summary=summary_paragraph_id,
+                paragraphs=[(paragraph, paragraph_id)],
             )
         else:
-            resources[rid].paragraphs.append((paragraph, extended_paragraph_text))
+            resources[rid].paragraphs.append((paragraph, paragraph_id))
 
     metrics.set("hierarchy_ops", len(resources))
+
+    augmented = await rpc.augment(
+        kbid,
+        AugmentRequest(
+            paragraphs=[
+                AugmentParagraphs(
+                    given=[
+                        AugmentParagraph(id=paragraph_id.full())
+                        for paragraph_id in paragraphs_to_augment
+                    ],
+                    text=True,
+                )
+            ]
+        ),
+    )
+
     augmented_paragraphs = set()
 
     # Modify the first paragraph of each resource to include the title and summary of the resource, as well as the
     # extended paragraph text of all the paragraphs in the resource.
     for values in resources.values():
-        title_text = values.title
-        summary_text = values.summary
+        augmented_title = augmented.paragraphs.get(values.title.full())
+        if augmented_title:
+            title_text = augmented_title.text or ""
+        else:
+            title_text = ""
+
+        augmented_summary = augmented.paragraphs.get(values.summary.full())
+        if augmented_summary:
+            summary_text = augmented_summary.text or ""
+        else:
+            summary_text = ""
+
         first_paragraph = None
         text_with_hierarchy = ""
-        for paragraph, extended_paragraph_text in values.paragraphs:
+        for paragraph, paragraph_id in values.paragraphs:
+            augmented_paragraph = augmented.paragraphs.get(paragraph_id.full())
+            if augmented_paragraph:
+                extended_paragraph_text = augmented_paragraph.text or ""
+            else:
+                extended_paragraph_text = ""
+
             if first_paragraph is None:
                 first_paragraph = paragraph
             text_with_hierarchy += "\n EXTRACTED BLOCK: \n " + extended_paragraph_text + " \n\n "
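The rewritten hierarchy strategy above follows a batch-then-lookup pattern: every ParagraphId (title, summary, and extended paragraphs) is collected first, resolved in a single `rpc.augment` round trip, and the results are read back through `augmented.paragraphs.get(id.full())`. A minimal sketch of that pattern, with a hypothetical `fetch_text` stand-in for the real resolver:

    import asyncio

    async def fetch_text(pid: str) -> tuple[str, str]:
        # Hypothetical single-paragraph resolver; the real code resolves the
        # whole batch server-side in one rpc.augment call.
        return pid, f"<text of {pid}>"

    async def augment_batch(paragraph_ids: list[str]) -> dict[str, str]:
        # Resolve everything concurrently, then index by full paragraph id so
        # callers can do results.get(pid), mirroring augmented.paragraphs.get(...).
        pairs = await asyncio.gather(*(fetch_text(pid) for pid in paragraph_ids))
        return dict(pairs)

    async def main() -> None:
        texts = await augment_batch(["rid/a/title/0-500", "rid/a/summary/0-1000"])
        print(texts.get("rid/a/title/0-500", ""))

    asyncio.run(main())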
@@ -1035,14 +976,14 @@ class PromptContextBuilder:
         self,
         kbid: str,
         ordered_paragraphs: list[FindParagraph],
-        resource: Optional[str] = None,
-        user_context: Optional[list[str]] = None,
-        user_image_context: Optional[list[Image]] = None,
-        strategies: Optional[Sequence[RagStrategy]] = None,
-        image_strategies: Optional[Sequence[ImageRagStrategy]] = None,
-        max_context_characters: Optional[int] = None,
+        resource: str | None = None,
+        user_context: list[str] | None = None,
+        user_image_context: list[Image] | None = None,
+        strategies: Sequence[RagStrategy] | None = None,
+        image_strategies: Sequence[ImageRagStrategy] | None = None,
+        max_context_characters: int | None = None,
         visual_llm: bool = False,
-        query_image: Optional[Image] = None,
+        query_image: Image | None = None,
         metrics: Metrics = Metrics("prompt_context_builder"),
     ):
         self.kbid = kbid
@@ -1088,10 +1029,10 @@ class PromptContextBuilder:
         if self.image_strategies is None or len(self.image_strategies) == 0:
             # Nothing to do
             return
-        page_image_strategy: Optional[PageImageStrategy] = None
+        page_image_strategy: PageImageStrategy | None = None
         max_page_images = 5
-        table_image_strategy: Optional[TableImageStrategy] = None
-        paragraph_image_strategy: Optional[ParagraphImageStrategy] = None
+        table_image_strategy: TableImageStrategy | None = None
+        paragraph_image_strategy: ParagraphImageStrategy | None = None
         for strategy in self.image_strategies:
             if strategy.name == ImageRagStrategyName.PAGE_IMAGE:
                 if page_image_strategy is None:
@@ -1121,7 +1062,12 @@ class PromptContextBuilder:
                 # page_image_id: rid/f/myfield/0
                 page_image_id = "/".join([pid.field_id.full(), str(paragraph_page_number)])
                 if page_image_id not in context.images:
-                    image = await get_page_image(self.kbid, pid, paragraph_page_number)
+                    image = await rpc.download_image(
+                        self.kbid,
+                        pid.field_id,
+                        f"generated/extracted_images_{paragraph_page_number}.png",
+                        mime_type="image/png",
+                    )
                     if image is not None:
                         ops += 1
                         context.images[page_image_id] = image
@@ -1141,7 +1087,9 @@ class PromptContextBuilder:
             if (add_table or add_paragraph) and (
                 paragraph.reference is not None and paragraph.reference != ""
             ):
-                pimage = await get_paragraph_image(self.kbid, pid, paragraph.reference)
+                pimage = await rpc.download_image(
+                    self.kbid, pid.field_id, f"generated/{paragraph.reference}", mime_type="image/png"
+                )
                 if pimage is not None:
                     ops += 1
                     context.images[paragraph.id] = pimage
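Both image hunks above replace the dedicated helpers (`get_page_image`, `get_paragraph_image`) with the generic `rpc.download_image` plus an explicit object key under `generated/`. A small sketch of the two identifiers involved in the page case (helper names are illustrative, not nucliadb API):

    def page_image_key(page_number: int) -> str:
        # Storage key convention visible in the diff: one PNG per extracted page.
        return f"generated/extracted_images_{page_number}.png"

    def page_image_cache_id(field_full_id: str, page_number: int) -> str:
        # Cache key used in context.images, e.g. "rid/f/myfield/0".
        return "/".join([field_full_id, str(page_number)])

    assert page_image_key(0) == "generated/extracted_images_0.png"
    assert page_image_cache_id("rid/f/myfield", 0) == "rid/f/myfield/0"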
@@ -1171,12 +1119,12 @@ class PromptContextBuilder:
             RagStrategyName.GRAPH,
         ]
 
-        full_resource: Optional[FullResourceStrategy] = None
-        hierarchy: Optional[HierarchyResourceStrategy] = None
-        neighbouring_paragraphs: Optional[NeighbouringParagraphsStrategy] = None
-        field_extension: Optional[FieldExtensionStrategy] = None
-        metadata_extension: Optional[MetadataExtensionStrategy] = None
-        conversational_strategy: Optional[ConversationalStrategy] = None
+        full_resource: FullResourceStrategy | None = None
+        hierarchy: HierarchyResourceStrategy | None = None
+        neighbouring_paragraphs: NeighbouringParagraphsStrategy | None = None
+        field_extension: FieldExtensionStrategy | None = None
+        metadata_extension: MetadataExtensionStrategy | None = None
+        conversational_strategy: ConversationalStrategy | None = None
         for strategy in self.strategies:
             if strategy.name == RagStrategyName.FIELD_EXTENSION:
                 field_extension = cast(FieldExtensionStrategy, strategy)
@@ -1269,7 +1217,7 @@ class PromptContextBuilder:
         )
 
 
-def get_paragraph_page_number(paragraph: FindParagraph) -> Optional[int]:
+def get_paragraph_page_number(paragraph: FindParagraph) -> int | None:
     if not paragraph.page_with_visual:
         return None
     if paragraph.position is None:
@@ -1279,9 +1227,9 @@ def get_paragraph_page_number(paragraph: FindParagraph) -> Optional[int]:
 
 @dataclass
 class ExtraCharsParagraph:
-    title: str
-    summary: str
-    paragraphs: List[Tuple[FindParagraph, str]]
+    title: ParagraphId
+    summary: ParagraphId
+    paragraphs: list[tuple[FindParagraph, ParagraphId]]
 
 
 def _clean_paragraph_text(paragraph: FindParagraph) -> str: