nucliadb 6.9.1.post5192__py3-none-any.whl → 6.10.0.post5705__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. migrations/0023_backfill_pg_catalog.py +2 -2
  2. migrations/0029_backfill_field_status.py +3 -4
  3. migrations/0032_remove_old_relations.py +2 -3
  4. migrations/0038_backfill_catalog_field_labels.py +2 -2
  5. migrations/0039_backfill_converation_splits_metadata.py +2 -2
  6. migrations/0041_reindex_conversations.py +137 -0
  7. migrations/pg/0010_shards_index.py +34 -0
  8. nucliadb/search/api/v1/resource/utils.py → migrations/pg/0011_catalog_statistics.py +5 -6
  9. migrations/pg/0012_catalog_statistics_undo.py +26 -0
  10. nucliadb/backups/create.py +2 -15
  11. nucliadb/backups/restore.py +4 -15
  12. nucliadb/backups/tasks.py +4 -1
  13. nucliadb/common/back_pressure/cache.py +2 -3
  14. nucliadb/common/back_pressure/materializer.py +7 -13
  15. nucliadb/common/back_pressure/settings.py +6 -6
  16. nucliadb/common/back_pressure/utils.py +1 -0
  17. nucliadb/common/cache.py +9 -9
  18. nucliadb/common/catalog/interface.py +12 -12
  19. nucliadb/common/catalog/pg.py +41 -29
  20. nucliadb/common/catalog/utils.py +3 -3
  21. nucliadb/common/cluster/manager.py +5 -4
  22. nucliadb/common/cluster/rebalance.py +483 -114
  23. nucliadb/common/cluster/rollover.py +25 -9
  24. nucliadb/common/cluster/settings.py +3 -8
  25. nucliadb/common/cluster/utils.py +34 -8
  26. nucliadb/common/context/__init__.py +7 -8
  27. nucliadb/common/context/fastapi.py +1 -2
  28. nucliadb/common/datamanagers/__init__.py +2 -4
  29. nucliadb/common/datamanagers/atomic.py +4 -2
  30. nucliadb/common/datamanagers/cluster.py +1 -2
  31. nucliadb/common/datamanagers/fields.py +3 -4
  32. nucliadb/common/datamanagers/kb.py +6 -6
  33. nucliadb/common/datamanagers/labels.py +2 -3
  34. nucliadb/common/datamanagers/resources.py +10 -33
  35. nucliadb/common/datamanagers/rollover.py +5 -7
  36. nucliadb/common/datamanagers/search_configurations.py +1 -2
  37. nucliadb/common/datamanagers/synonyms.py +1 -2
  38. nucliadb/common/datamanagers/utils.py +4 -4
  39. nucliadb/common/datamanagers/vectorsets.py +4 -4
  40. nucliadb/common/external_index_providers/base.py +32 -5
  41. nucliadb/common/external_index_providers/manager.py +4 -5
  42. nucliadb/common/filter_expression.py +128 -40
  43. nucliadb/common/http_clients/processing.py +12 -23
  44. nucliadb/common/ids.py +6 -4
  45. nucliadb/common/locking.py +1 -2
  46. nucliadb/common/maindb/driver.py +9 -8
  47. nucliadb/common/maindb/local.py +5 -5
  48. nucliadb/common/maindb/pg.py +9 -8
  49. nucliadb/common/nidx.py +3 -4
  50. nucliadb/export_import/datamanager.py +4 -3
  51. nucliadb/export_import/exporter.py +11 -19
  52. nucliadb/export_import/importer.py +13 -6
  53. nucliadb/export_import/tasks.py +2 -0
  54. nucliadb/export_import/utils.py +6 -18
  55. nucliadb/health.py +2 -2
  56. nucliadb/ingest/app.py +8 -8
  57. nucliadb/ingest/consumer/consumer.py +8 -10
  58. nucliadb/ingest/consumer/pull.py +3 -8
  59. nucliadb/ingest/consumer/service.py +3 -3
  60. nucliadb/ingest/consumer/utils.py +1 -1
  61. nucliadb/ingest/fields/base.py +28 -49
  62. nucliadb/ingest/fields/conversation.py +12 -12
  63. nucliadb/ingest/fields/exceptions.py +1 -2
  64. nucliadb/ingest/fields/file.py +22 -8
  65. nucliadb/ingest/fields/link.py +7 -7
  66. nucliadb/ingest/fields/text.py +2 -3
  67. nucliadb/ingest/orm/brain_v2.py +78 -64
  68. nucliadb/ingest/orm/broker_message.py +2 -4
  69. nucliadb/ingest/orm/entities.py +10 -209
  70. nucliadb/ingest/orm/index_message.py +4 -4
  71. nucliadb/ingest/orm/knowledgebox.py +18 -27
  72. nucliadb/ingest/orm/processor/auditing.py +1 -3
  73. nucliadb/ingest/orm/processor/data_augmentation.py +1 -2
  74. nucliadb/ingest/orm/processor/processor.py +27 -27
  75. nucliadb/ingest/orm/processor/sequence_manager.py +1 -2
  76. nucliadb/ingest/orm/resource.py +72 -70
  77. nucliadb/ingest/orm/utils.py +1 -1
  78. nucliadb/ingest/processing.py +17 -17
  79. nucliadb/ingest/serialize.py +202 -145
  80. nucliadb/ingest/service/writer.py +3 -109
  81. nucliadb/ingest/settings.py +3 -4
  82. nucliadb/ingest/utils.py +1 -2
  83. nucliadb/learning_proxy.py +11 -11
  84. nucliadb/metrics_exporter.py +5 -4
  85. nucliadb/middleware/__init__.py +82 -1
  86. nucliadb/migrator/datamanager.py +3 -4
  87. nucliadb/migrator/migrator.py +1 -2
  88. nucliadb/migrator/models.py +1 -2
  89. nucliadb/migrator/settings.py +1 -2
  90. nucliadb/models/internal/augment.py +614 -0
  91. nucliadb/models/internal/processing.py +19 -19
  92. nucliadb/openapi.py +2 -2
  93. nucliadb/purge/__init__.py +3 -8
  94. nucliadb/purge/orphan_shards.py +1 -2
  95. nucliadb/reader/__init__.py +5 -0
  96. nucliadb/reader/api/models.py +6 -13
  97. nucliadb/reader/api/v1/download.py +59 -38
  98. nucliadb/reader/api/v1/export_import.py +4 -4
  99. nucliadb/reader/api/v1/learning_config.py +24 -4
  100. nucliadb/reader/api/v1/resource.py +61 -9
  101. nucliadb/reader/api/v1/services.py +18 -14
  102. nucliadb/reader/app.py +3 -1
  103. nucliadb/reader/reader/notifications.py +1 -2
  104. nucliadb/search/api/v1/__init__.py +2 -0
  105. nucliadb/search/api/v1/ask.py +3 -4
  106. nucliadb/search/api/v1/augment.py +585 -0
  107. nucliadb/search/api/v1/catalog.py +11 -15
  108. nucliadb/search/api/v1/find.py +16 -22
  109. nucliadb/search/api/v1/hydrate.py +25 -25
  110. nucliadb/search/api/v1/knowledgebox.py +1 -2
  111. nucliadb/search/api/v1/predict_proxy.py +1 -2
  112. nucliadb/search/api/v1/resource/ask.py +7 -7
  113. nucliadb/search/api/v1/resource/ingestion_agents.py +5 -6
  114. nucliadb/search/api/v1/resource/search.py +9 -11
  115. nucliadb/search/api/v1/retrieve.py +130 -0
  116. nucliadb/search/api/v1/search.py +28 -32
  117. nucliadb/search/api/v1/suggest.py +11 -14
  118. nucliadb/search/api/v1/summarize.py +1 -2
  119. nucliadb/search/api/v1/utils.py +2 -2
  120. nucliadb/search/app.py +3 -2
  121. nucliadb/search/augmentor/__init__.py +21 -0
  122. nucliadb/search/augmentor/augmentor.py +232 -0
  123. nucliadb/search/augmentor/fields.py +704 -0
  124. nucliadb/search/augmentor/metrics.py +24 -0
  125. nucliadb/search/augmentor/paragraphs.py +334 -0
  126. nucliadb/search/augmentor/resources.py +238 -0
  127. nucliadb/search/augmentor/utils.py +33 -0
  128. nucliadb/search/lifecycle.py +3 -1
  129. nucliadb/search/predict.py +24 -17
  130. nucliadb/search/predict_models.py +8 -9
  131. nucliadb/search/requesters/utils.py +11 -10
  132. nucliadb/search/search/cache.py +19 -23
  133. nucliadb/search/search/chat/ask.py +88 -59
  134. nucliadb/search/search/chat/exceptions.py +3 -5
  135. nucliadb/search/search/chat/fetcher.py +201 -0
  136. nucliadb/search/search/chat/images.py +6 -4
  137. nucliadb/search/search/chat/old_prompt.py +1375 -0
  138. nucliadb/search/search/chat/parser.py +510 -0
  139. nucliadb/search/search/chat/prompt.py +563 -615
  140. nucliadb/search/search/chat/query.py +449 -36
  141. nucliadb/search/search/chat/rpc.py +85 -0
  142. nucliadb/search/search/fetch.py +3 -4
  143. nucliadb/search/search/filters.py +8 -11
  144. nucliadb/search/search/find.py +33 -31
  145. nucliadb/search/search/find_merge.py +124 -331
  146. nucliadb/search/search/graph_strategy.py +14 -12
  147. nucliadb/search/search/hydrator/__init__.py +3 -152
  148. nucliadb/search/search/hydrator/fields.py +92 -50
  149. nucliadb/search/search/hydrator/images.py +7 -7
  150. nucliadb/search/search/hydrator/paragraphs.py +42 -26
  151. nucliadb/search/search/hydrator/resources.py +20 -16
  152. nucliadb/search/search/ingestion_agents.py +5 -5
  153. nucliadb/search/search/merge.py +90 -94
  154. nucliadb/search/search/metrics.py +10 -9
  155. nucliadb/search/search/paragraphs.py +7 -9
  156. nucliadb/search/search/predict_proxy.py +13 -9
  157. nucliadb/search/search/query.py +14 -86
  158. nucliadb/search/search/query_parser/fetcher.py +51 -82
  159. nucliadb/search/search/query_parser/models.py +19 -20
  160. nucliadb/search/search/query_parser/old_filters.py +20 -19
  161. nucliadb/search/search/query_parser/parsers/ask.py +4 -5
  162. nucliadb/search/search/query_parser/parsers/catalog.py +5 -6
  163. nucliadb/search/search/query_parser/parsers/common.py +5 -6
  164. nucliadb/search/search/query_parser/parsers/find.py +6 -26
  165. nucliadb/search/search/query_parser/parsers/graph.py +13 -23
  166. nucliadb/search/search/query_parser/parsers/retrieve.py +207 -0
  167. nucliadb/search/search/query_parser/parsers/search.py +15 -53
  168. nucliadb/search/search/query_parser/parsers/unit_retrieval.py +8 -29
  169. nucliadb/search/search/rank_fusion.py +18 -13
  170. nucliadb/search/search/rerankers.py +5 -6
  171. nucliadb/search/search/retrieval.py +300 -0
  172. nucliadb/search/search/summarize.py +5 -6
  173. nucliadb/search/search/utils.py +3 -4
  174. nucliadb/search/settings.py +1 -2
  175. nucliadb/standalone/api_router.py +1 -1
  176. nucliadb/standalone/app.py +4 -3
  177. nucliadb/standalone/auth.py +5 -6
  178. nucliadb/standalone/lifecycle.py +2 -2
  179. nucliadb/standalone/run.py +2 -4
  180. nucliadb/standalone/settings.py +5 -6
  181. nucliadb/standalone/versions.py +3 -4
  182. nucliadb/tasks/consumer.py +13 -8
  183. nucliadb/tasks/models.py +2 -1
  184. nucliadb/tasks/producer.py +3 -3
  185. nucliadb/tasks/retries.py +8 -7
  186. nucliadb/train/api/utils.py +1 -3
  187. nucliadb/train/api/v1/shards.py +1 -2
  188. nucliadb/train/api/v1/trainset.py +1 -2
  189. nucliadb/train/app.py +1 -1
  190. nucliadb/train/generator.py +4 -4
  191. nucliadb/train/generators/field_classifier.py +2 -2
  192. nucliadb/train/generators/field_streaming.py +6 -6
  193. nucliadb/train/generators/image_classifier.py +2 -2
  194. nucliadb/train/generators/paragraph_classifier.py +2 -2
  195. nucliadb/train/generators/paragraph_streaming.py +2 -2
  196. nucliadb/train/generators/question_answer_streaming.py +2 -2
  197. nucliadb/train/generators/sentence_classifier.py +2 -2
  198. nucliadb/train/generators/token_classifier.py +3 -2
  199. nucliadb/train/generators/utils.py +6 -5
  200. nucliadb/train/nodes.py +3 -3
  201. nucliadb/train/resource.py +6 -8
  202. nucliadb/train/settings.py +3 -4
  203. nucliadb/train/types.py +11 -11
  204. nucliadb/train/upload.py +3 -2
  205. nucliadb/train/uploader.py +1 -2
  206. nucliadb/train/utils.py +1 -2
  207. nucliadb/writer/api/v1/export_import.py +4 -1
  208. nucliadb/writer/api/v1/field.py +7 -11
  209. nucliadb/writer/api/v1/knowledgebox.py +3 -4
  210. nucliadb/writer/api/v1/resource.py +9 -20
  211. nucliadb/writer/api/v1/services.py +10 -132
  212. nucliadb/writer/api/v1/upload.py +73 -72
  213. nucliadb/writer/app.py +8 -2
  214. nucliadb/writer/resource/basic.py +12 -15
  215. nucliadb/writer/resource/field.py +7 -5
  216. nucliadb/writer/resource/origin.py +7 -0
  217. nucliadb/writer/settings.py +2 -3
  218. nucliadb/writer/tus/__init__.py +2 -3
  219. nucliadb/writer/tus/azure.py +1 -3
  220. nucliadb/writer/tus/dm.py +3 -3
  221. nucliadb/writer/tus/exceptions.py +3 -4
  222. nucliadb/writer/tus/gcs.py +5 -6
  223. nucliadb/writer/tus/s3.py +2 -3
  224. nucliadb/writer/tus/storage.py +3 -3
  225. {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/METADATA +9 -10
  226. nucliadb-6.10.0.post5705.dist-info/RECORD +410 -0
  227. nucliadb/common/datamanagers/entities.py +0 -139
  228. nucliadb-6.9.1.post5192.dist-info/RECORD +0 -392
  229. {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/WHEEL +0 -0
  230. {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/entry_points.txt +0 -0
  231. {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/top_level.txt +0 -0
@@ -20,7 +20,8 @@
20
20
  import dataclasses
21
21
  import functools
22
22
  import json
23
- from typing import AsyncGenerator, Optional, Union, cast
23
+ from collections.abc import AsyncGenerator
24
+ from typing import cast
24
25
 
25
26
  from nuclia_models.common.consumption import Consumption
26
27
  from nuclia_models.predict.generative_responses import (
@@ -34,6 +35,7 @@ from nuclia_models.predict.generative_responses import (
34
35
  TextGenerativeResponse,
35
36
  )
36
37
  from pydantic_core import ValidationError
38
+ from typing_extensions import assert_never
37
39
 
38
40
  from nucliadb.common.datamanagers.exceptions import KnowledgeBoxNotFound
39
41
  from nucliadb.common.exceptions import InvalidQueryError
@@ -49,11 +51,13 @@ from nucliadb.search.search.chat.exceptions import (
49
51
  AnswerJsonSchemaTooLong,
50
52
  NoRetrievalResultsError,
51
53
  )
54
+ from nucliadb.search.search.chat.old_prompt import PromptContextBuilder as OldPromptContextBuilder
52
55
  from nucliadb.search.search.chat.prompt import PromptContextBuilder
53
56
  from nucliadb.search.search.chat.query import (
54
57
  NOT_ENOUGH_CONTEXT_ANSWER,
55
58
  ChatAuditor,
56
59
  add_resource_filter,
60
+ get_answer_stream,
57
61
  get_find_results,
58
62
  get_relations_results,
59
63
  maybe_audit_chat,
@@ -69,11 +73,15 @@ from nucliadb.search.search.metrics import AskMetrics, Metrics
69
73
  from nucliadb.search.search.query_parser.fetcher import Fetcher
70
74
  from nucliadb.search.search.query_parser.parsers.ask import fetcher_for_ask, parse_ask
71
75
  from nucliadb.search.search.rank_fusion import WeightedCombSum
72
- from nucliadb.search.search.rerankers import (
73
- get_reranker,
76
+ from nucliadb_models.retrieval import (
77
+ GraphScore,
78
+ KeywordScore,
79
+ RerankerScore,
80
+ RrfScore,
81
+ SemanticScore,
74
82
  )
75
- from nucliadb.search.utilities import get_predict
76
83
  from nucliadb_models.search import (
84
+ SCORE_TYPE,
77
85
  AnswerAskResponseItem,
78
86
  AskRequest,
79
87
  AskResponseItem,
@@ -118,7 +126,9 @@ from nucliadb_models.search import (
118
126
  parse_rephrase_prompt,
119
127
  )
120
128
  from nucliadb_telemetry import errors
129
+ from nucliadb_utils import const
121
130
  from nucliadb_utils.exceptions import LimitsExceededError
131
+ from nucliadb_utils.utilities import has_feature
122
132
 
123
133
 
124
134
  @dataclasses.dataclass
@@ -132,7 +142,7 @@ class RetrievalResults:
132
142
  main_query: KnowledgeboxFindResults
133
143
  fetcher: Fetcher
134
144
  main_query_weight: float
135
- prequeries: Optional[list[PreQueryResult]] = None
145
+ prequeries: list[PreQueryResult] | None = None
136
146
  best_matches: list[RetrievalMatch] = dataclasses.field(default_factory=list)
137
147
 
138
148
 
@@ -143,15 +153,15 @@ class AskResult:
143
153
  kbid: str,
144
154
  ask_request: AskRequest,
145
155
  main_results: KnowledgeboxFindResults,
146
- prequeries_results: Optional[list[PreQueryResult]],
147
- nuclia_learning_id: Optional[str],
148
- predict_answer_stream: Optional[AsyncGenerator[GenerativeChunk, None]],
156
+ prequeries_results: list[PreQueryResult] | None,
157
+ nuclia_learning_id: str | None,
158
+ predict_answer_stream: AsyncGenerator[GenerativeChunk, None] | None,
149
159
  prompt_context: PromptContext,
150
160
  prompt_context_order: PromptContextOrder,
151
161
  auditor: ChatAuditor,
152
162
  metrics: AskMetrics,
153
163
  best_matches: list[RetrievalMatch],
154
- debug_chat_model: Optional[ChatModel],
164
+ debug_chat_model: ChatModel | None,
155
165
  augmented_context: AugmentedContext,
156
166
  ):
157
167
  # Initial attributes
@@ -171,14 +181,14 @@ class AskResult:
171
181
 
172
182
  # Computed from the predict chat answer stream
173
183
  self._answer_text = ""
174
- self._reasoning_text: Optional[str] = None
175
- self._object: Optional[JSONGenerativeResponse] = None
176
- self._status: Optional[StatusGenerativeResponse] = None
177
- self._citations: Optional[CitationsGenerativeResponse] = None
178
- self._footnote_citations: Optional[FootnoteCitationsGenerativeResponse] = None
179
- self._metadata: Optional[MetaGenerativeResponse] = None
180
- self._relations: Optional[Relations] = None
181
- self._consumption: Optional[Consumption] = None
184
+ self._reasoning_text: str | None = None
185
+ self._object: JSONGenerativeResponse | None = None
186
+ self._status: StatusGenerativeResponse | None = None
187
+ self._citations: CitationsGenerativeResponse | None = None
188
+ self._footnote_citations: FootnoteCitationsGenerativeResponse | None = None
189
+ self._metadata: MetaGenerativeResponse | None = None
190
+ self._relations: Relations | None = None
191
+ self._consumption: Consumption | None = None
182
192
 
183
193
  @property
184
194
  def status_code(self) -> AnswerStatusCode:
@@ -187,7 +197,7 @@ class AskResult:
187
197
  return AnswerStatusCode(self._status.code)
188
198
 
189
199
  @property
190
- def status_error_details(self) -> Optional[str]:
200
+ def status_error_details(self) -> str | None:
191
201
  if self._status is None: # pragma: no cover
192
202
  return None
193
203
  return self._status.details
@@ -240,9 +250,7 @@ class AskResult:
240
250
  self.metrics.record_first_reasoning_chunk_yielded()
241
251
  first_reasoning_chunk_yielded = True
242
252
  else:
243
- # This is a trick so mypy generates an error if this branch can be reached,
244
- # that is, if we are missing some ifs
245
- _a: int = "a"
253
+ assert_never(answer_chunk)
246
254
 
247
255
  if self._object is not None:
248
256
  yield JSONAskResponseItem(object=self._object.object)
@@ -396,7 +404,7 @@ class AskResult:
396
404
  if self._object is not None:
397
405
  answer_json = self._object.object
398
406
 
399
- prequeries_results: Optional[dict[str, KnowledgeboxFindResults]] = None
407
+ prequeries_results: dict[str, KnowledgeboxFindResults] | None = None
400
408
  if self.prequeries_results:
401
409
  prequeries_results = {}
402
410
  for index, (prequery, result) in enumerate(self.prequeries_results):
@@ -452,7 +460,7 @@ class AskResult:
452
460
 
453
461
  async def _stream_predict_answer_text(
454
462
  self,
455
- ) -> AsyncGenerator[Union[TextGenerativeResponse, ReasoningGenerativeResponse], None]:
463
+ ) -> AsyncGenerator[TextGenerativeResponse | ReasoningGenerativeResponse, None]:
456
464
  """
457
465
  Reads the stream of the generative model, yielding the answer text but also parsing
458
466
  other items like status codes, citations and miscellaneous metadata.
@@ -496,8 +504,8 @@ class AskResult:
496
504
  class NotEnoughContextAskResult(AskResult):
497
505
  def __init__(
498
506
  self,
499
- main_results: Optional[KnowledgeboxFindResults] = None,
500
- prequeries_results: Optional[list[PreQueryResult]] = None,
507
+ main_results: KnowledgeboxFindResults | None = None,
508
+ prequeries_results: list[PreQueryResult] | None = None,
501
509
  ):
502
510
  self.main_results = main_results or KnowledgeboxFindResults(resources={}, min_score=None)
503
511
  self.prequeries_results = prequeries_results or []
@@ -547,8 +555,8 @@ async def ask(
547
555
  user_id: str,
548
556
  client_type: NucliaDBClientType,
549
557
  origin: str,
550
- resource: Optional[str] = None,
551
- extra_predict_headers: Optional[dict[str, str]] = None,
558
+ resource: str | None = None,
559
+ extra_predict_headers: dict[str, str] | None = None,
552
560
  ) -> AskResult:
553
561
  metrics = AskMetrics()
554
562
  chat_history = ask_request.chat_history or []
@@ -627,19 +635,36 @@ async def ask(
627
635
 
628
636
  # Now we build the prompt context
629
637
  with metrics.time("context_building"):
630
- prompt_context_builder = PromptContextBuilder(
631
- kbid=kbid,
632
- ordered_paragraphs=[match.paragraph for match in retrieval_results.best_matches],
633
- resource=resource,
634
- user_context=user_context,
635
- user_image_context=ask_request.extra_context_images,
636
- strategies=ask_request.rag_strategies,
637
- image_strategies=ask_request.rag_images_strategies,
638
- max_context_characters=tokens_to_chars(generation.max_context_tokens),
639
- visual_llm=generation.use_visual_llm,
640
- query_image=ask_request.query_image,
641
- metrics=metrics.child_span("context_building"),
642
- )
638
+ prompt_context_builder: PromptContextBuilder | OldPromptContextBuilder
639
+ if has_feature(const.Features.ASK_DECOUPLED, context={"kbid": kbid}):
640
+ prompt_context_builder = PromptContextBuilder(
641
+ kbid=kbid,
642
+ ordered_paragraphs=[match.paragraph for match in retrieval_results.best_matches],
643
+ resource=resource,
644
+ user_context=user_context,
645
+ user_image_context=ask_request.extra_context_images,
646
+ strategies=ask_request.rag_strategies,
647
+ image_strategies=ask_request.rag_images_strategies,
648
+ max_context_characters=tokens_to_chars(generation.max_context_tokens),
649
+ visual_llm=generation.use_visual_llm,
650
+ query_image=ask_request.query_image,
651
+ metrics=metrics.child_span("context_building"),
652
+ )
653
+ else:
654
+ prompt_context_builder = OldPromptContextBuilder(
655
+ kbid=kbid,
656
+ ordered_paragraphs=[match.paragraph for match in retrieval_results.best_matches],
657
+ resource=resource,
658
+ user_context=user_context,
659
+ user_image_context=ask_request.extra_context_images,
660
+ strategies=ask_request.rag_strategies,
661
+ image_strategies=ask_request.rag_images_strategies,
662
+ max_context_characters=tokens_to_chars(generation.max_context_tokens),
663
+ visual_llm=generation.use_visual_llm,
664
+ query_image=ask_request.query_image,
665
+ metrics=metrics.child_span("context_building"),
666
+ )
667
+
643
668
  (
644
669
  prompt_context,
645
670
  prompt_context_order,
@@ -675,14 +700,11 @@ async def ask(
675
700
  predict_answer_stream = None
676
701
  if ask_request.generate_answer:
677
702
  with metrics.time("stream_start"):
678
- predict = get_predict()
679
703
  (
680
704
  nuclia_learning_id,
681
705
  nuclia_learning_model,
682
706
  predict_answer_stream,
683
- ) = await predict.chat_query_ndjson(
684
- kbid=kbid, item=chat_model, extra_headers=extra_predict_headers
685
- )
707
+ ) = await get_answer_stream(kbid=kbid, item=chat_model, extra_headers=extra_predict_headers)
686
708
 
687
709
  auditor = ChatAuditor(
688
710
  kbid=kbid,
@@ -757,7 +779,7 @@ def handled_ask_exceptions(func):
757
779
  return wrapper
758
780
 
759
781
 
760
- def parse_prequeries(ask_request: AskRequest) -> Optional[PreQueriesStrategy]:
782
+ def parse_prequeries(ask_request: AskRequest) -> PreQueriesStrategy | None:
761
783
  query_ids = []
762
784
  for rag_strategy in ask_request.rag_strategies:
763
785
  if rag_strategy.name == RagStrategyName.PREQUERIES:
@@ -776,7 +798,7 @@ def parse_prequeries(ask_request: AskRequest) -> Optional[PreQueriesStrategy]:
776
798
  return None
777
799
 
778
800
 
779
- def parse_graph_strategy(ask_request: AskRequest) -> Optional[GraphStrategy]:
801
+ def parse_graph_strategy(ask_request: AskRequest) -> GraphStrategy | None:
780
802
  for rag_strategy in ask_request.rag_strategies:
781
803
  if rag_strategy.name == RagStrategyName.GRAPH:
782
804
  return cast(GraphStrategy, rag_strategy)
@@ -791,7 +813,7 @@ async def retrieval_step(
791
813
  user_id: str,
792
814
  origin: str,
793
815
  metrics: Metrics,
794
- resource: Optional[str] = None,
816
+ resource: str | None = None,
795
817
  ) -> RetrievalResults:
796
818
  """
797
819
  This function encapsulates all the logic related to retrieval in the ask endpoint.
@@ -830,7 +852,7 @@ async def retrieval_in_kb(
830
852
  ) -> RetrievalResults:
831
853
  prequeries = parse_prequeries(ask_request)
832
854
  graph_strategy = parse_graph_strategy(ask_request)
833
- main_results, prequeries_results, parsed_query = await get_find_results(
855
+ main_results, prequeries_results, fetcher, reranker = await get_find_results(
834
856
  kbid=kbid,
835
857
  query=main_query,
836
858
  item=ask_request,
@@ -842,10 +864,6 @@ async def retrieval_in_kb(
842
864
  )
843
865
 
844
866
  if graph_strategy is not None:
845
- assert parsed_query.retrieval.reranker is not None, (
846
- "find parser must provide a reranking algorithm"
847
- )
848
- reranker = get_reranker(parsed_query.retrieval.reranker)
849
867
  graph_results, graph_request = await get_graph_results(
850
868
  kbid=kbid,
851
869
  query=main_query,
@@ -878,7 +896,7 @@ async def retrieval_in_kb(
878
896
  return RetrievalResults(
879
897
  main_query=main_results,
880
898
  prequeries=prequeries_results,
881
- fetcher=parsed_query.fetcher,
899
+ fetcher=fetcher,
882
900
  main_query_weight=main_query_weight,
883
901
  best_matches=best_matches,
884
902
  )
@@ -918,7 +936,7 @@ async def retrieval_in_resource(
918
936
  )
919
937
  add_resource_filter(prequery.request, [resource])
920
938
 
921
- main_results, prequeries_results, parsed_query = await get_find_results(
939
+ main_results, prequeries_results, fetcher, _ = await get_find_results(
922
940
  kbid=kbid,
923
941
  query=main_query,
924
942
  item=ask_request,
@@ -941,7 +959,7 @@ async def retrieval_in_resource(
941
959
  return RetrievalResults(
942
960
  main_query=main_results,
943
961
  prequeries=prequeries_results,
944
- fetcher=parsed_query.fetcher,
962
+ fetcher=fetcher,
945
963
  main_query_weight=main_query_weight,
946
964
  best_matches=best_matches,
947
965
  )
@@ -953,7 +971,7 @@ class _FindParagraph(ScoredTextBlock):
953
971
 
954
972
  def compute_best_matches(
955
973
  main_results: KnowledgeboxFindResults,
956
- prequeries_results: Optional[list[PreQueryResult]] = None,
974
+ prequeries_results: list[PreQueryResult] | None = None,
957
975
  main_query_weight: float = 1.0,
958
976
  ) -> list[RetrievalMatch]:
959
977
  """
@@ -968,15 +986,27 @@ def compute_best_matches(
968
986
  `main_query_weight` is the weight given to the paragraphs matching the main query when calculating the final score.
969
987
  """
970
988
 
989
+ score_type_map = {
990
+ SCORE_TYPE.VECTOR: SemanticScore,
991
+ SCORE_TYPE.BM25: KeywordScore,
992
+ SCORE_TYPE.BOTH: RrfScore, # /find only exposes RRF as rank fusion algorithm
993
+ SCORE_TYPE.RERANKER: RerankerScore,
994
+ SCORE_TYPE.RELATION_RELEVANCE: GraphScore,
995
+ }
996
+
971
997
  def extract_paragraphs(results: KnowledgeboxFindResults) -> list[_FindParagraph]:
972
998
  paragraphs = []
973
999
  for resource in results.resources.values():
974
1000
  for field in resource.fields.values():
975
1001
  for paragraph in field.paragraphs.values():
1002
+ # TODO(decoupled-ask): we don't know the score history, as
1003
+ # we are using find results. Once we move boolean queries
1004
+ # inside the new retrieval flow we'll move this and have the
1005
+ # proper information to do this rank fusion
976
1006
  paragraphs.append(
977
1007
  _FindParagraph(
978
1008
  paragraph_id=ParagraphId.from_string(paragraph.id),
979
- score=paragraph.score,
1009
+ scores=[score_type_map[paragraph.score_type](score=paragraph.score)],
980
1010
  score_type=paragraph.score_type,
981
1011
  original=paragraph,
982
1012
  )
@@ -1012,7 +1042,7 @@ def compute_best_matches(
1012
1042
 
1013
1043
  def calculate_prequeries_for_json_schema(
1014
1044
  ask_request: AskRequest,
1015
- ) -> Optional[PreQueriesStrategy]:
1045
+ ) -> PreQueriesStrategy | None:
1016
1046
  """
1017
1047
  This function generates a PreQueriesStrategy with a query for each property in the JSON schema
1018
1048
  found in ask_request.answer_json_schema.
@@ -1077,7 +1107,6 @@ def calculate_prequeries_for_json_schema(
1077
1107
  rephrase=ask_request.rephrase,
1078
1108
  rephrase_prompt=parse_rephrase_prompt(ask_request),
1079
1109
  security=ask_request.security,
1080
- autofilter=False,
1081
1110
  )
1082
1111
  prequery = PreQuery(
1083
1112
  request=req,
@@ -19,17 +19,15 @@
19
19
  #
20
20
 
21
21
 
22
- from typing import Optional
23
-
24
22
  from nucliadb_models.search import KnowledgeboxFindResults, PreQueryResult
25
23
 
26
24
 
27
25
  class NoRetrievalResultsError(Exception):
28
26
  def __init__(
29
27
  self,
30
- main: Optional[KnowledgeboxFindResults] = None,
31
- prequeries: Optional[list[PreQueryResult]] = None,
32
- prefilters: Optional[list[PreQueryResult]] = None,
28
+ main: KnowledgeboxFindResults | None = None,
29
+ prequeries: list[PreQueryResult] | None = None,
30
+ prefilters: list[PreQueryResult] | None = None,
33
31
  ):
34
32
  self.main_query = main
35
33
  self.prequeries = prequeries
@@ -0,0 +1,201 @@
1
+ # Copyright (C) 2021 Bosutech XXI S.L.
2
+ #
3
+ # nucliadb is offered under the AGPL v3.0 and as commercial software.
4
+ # For commercial licensing, contact us at info@nuclia.com.
5
+ #
6
+ # AGPL:
7
+ # This program is free software: you can redistribute it and/or modify
8
+ # it under the terms of the GNU Affero General Public License as
9
+ # published by the Free Software Foundation, either version 3 of the
10
+ # License, or (at your option) any later version.
11
+ #
12
+ # This program is distributed in the hope that it will be useful,
13
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
14
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15
+ # GNU Affero General Public License for more details.
16
+ #
17
+ # You should have received a copy of the GNU Affero General Public License
18
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
+ #
20
+
21
+
22
+ from google.protobuf.json_format import ParseDict
23
+
24
+ from nucliadb.common.exceptions import InvalidQueryError
25
+ from nucliadb.search import logger
26
+ from nucliadb.search.predict import SendToPredictError, convert_relations
27
+ from nucliadb.search.predict_models import QueryModel
28
+ from nucliadb.search.search.chat import rpc
29
+ from nucliadb.search.search.query_parser.fetcher import Fetcher
30
+ from nucliadb.search.utilities import get_predict
31
+ from nucliadb_models.internal.predict import QueryInfo
32
+ from nucliadb_models.search import Image, MaxTokens
33
+ from nucliadb_protos import knowledgebox_pb2, utils_pb2
34
+
35
+
36
+ class RAOFetcher(Fetcher):
37
+ def __init__(
38
+ self,
39
+ kbid: str,
40
+ *,
41
+ query: str,
42
+ user_vector: list[float] | None,
43
+ vectorset: str | None,
44
+ rephrase: bool,
45
+ rephrase_prompt: str | None,
46
+ generative_model: str | None,
47
+ query_image: Image | None,
48
+ ):
49
+ super().__init__(
50
+ kbid,
51
+ query=query,
52
+ user_vector=user_vector,
53
+ vectorset=vectorset,
54
+ rephrase=rephrase,
55
+ rephrase_prompt=rephrase_prompt,
56
+ generative_model=generative_model,
57
+ query_image=query_image,
58
+ )
59
+
60
+ self._query_info: QueryInfo | None = None
61
+ self._vectorset: str | None = None
62
+
63
+ async def query_information(self) -> QueryInfo:
64
+ if self._query_info is None:
65
+ self._query_info = await query_information(
66
+ kbid=self.kbid,
67
+ query=self.query,
68
+ semantic_model=self.user_vectorset,
69
+ generative_model=self.generative_model,
70
+ rephrase=self.rephrase,
71
+ rephrase_prompt=self.rephrase_prompt,
72
+ query_image=self.query_image,
73
+ )
74
+ return self._query_info
75
+
76
+ # Retrieval
77
+
78
+ async def get_rephrased_query(self) -> str | None:
79
+ query_info = await self.query_information()
80
+ return query_info.rephrased_query
81
+
82
+ async def get_detected_entities(self) -> list[utils_pb2.RelationNode]:
83
+ query_info = await self.query_information()
84
+ if query_info.entities is not None:
85
+ detected_entities = convert_relations(query_info.entities.model_dump())
86
+ else:
87
+ detected_entities = []
88
+ return detected_entities
89
+
90
+ async def get_semantic_min_score(self) -> float | None:
91
+ query_info = await self.query_information()
92
+ vectorset = await self.get_vectorset()
93
+ return query_info.semantic_thresholds.get(vectorset, None)
94
+
95
+ async def get_vectorset(self) -> str:
96
+ if self._vectorset is None:
97
+ if self.user_vectorset is not None:
98
+ self._vectorset = self.user_vectorset
99
+ else:
100
+ # when it's not provided, we get the default from Predict API
101
+ query_info = await self.query_information()
102
+ if query_info.sentence is None or len(query_info.sentence.vectors) == 0:
103
+ logger.error(
104
+ "Asking for a vectorset but /query didn't return one", extra={"kbid": self.kbid}
105
+ )
106
+ raise SendToPredictError("Predict API didn't return a sentence vectorset")
107
+ # vectors field is enforced by the data model to have at least one key
108
+ for vectorset in query_info.sentence.vectors.keys():
109
+ self._vectorset = vectorset
110
+ break
111
+ assert self._vectorset is not None
112
+ return self._vectorset
113
+
114
+ async def get_query_vector(self) -> list[float]:
115
+ if self.user_vector is not None:
116
+ return self.user_vector
117
+
118
+ query_info = await self.query_information()
119
+ if query_info.sentence is None:
120
+ logger.error(
121
+ "Asking for a semantic query vector but /query didn't return a sentence",
122
+ extra={"kbid": self.kbid},
123
+ )
124
+ raise SendToPredictError("Predict API didn't return a sentence for semantic search")
125
+
126
+ vectorset = await self.get_vectorset()
127
+ if vectorset not in query_info.sentence.vectors:
128
+ logger.error(
129
+ "Predict is not responding with a valid query nucliadb vectorset",
130
+ extra={
131
+ "kbid": self.kbid,
132
+ "vectorset": vectorset,
133
+ "predict_vectorsets": ",".join(query_info.sentence.vectors.keys()),
134
+ },
135
+ )
136
+ raise SendToPredictError("Predict API didn't return the requested vectorset")
137
+
138
+ query_vector = query_info.sentence.vectors[vectorset]
139
+ return query_vector
140
+
141
+ async def get_classification_labels(self) -> knowledgebox_pb2.Labels:
142
+ labelsets = await rpc.labelsets(self.kbid)
143
+
144
+ # TODO(decoupled-ask): remove this conversion and refactor code to use API models instead of protobuf
145
+ kb_labels = knowledgebox_pb2.Labels()
146
+ for labelset, labels in labelsets.labelsets.items():
147
+ ParseDict(labels.model_dump(), kb_labels.labelset[labelset])
148
+
149
+ return kb_labels
150
+
151
+ # Generative
152
+
153
+ async def get_visual_llm_enabled(self) -> bool:
154
+ query_info = await self.query_information()
155
+ if query_info is None:
156
+ raise SendToPredictError("Error while using predict's query endpoint")
157
+
158
+ return query_info.visual_llm
159
+
160
+ async def get_max_context_tokens(self, max_tokens: MaxTokens | None) -> int:
161
+ query_info = await self.query_information()
162
+ if query_info is None:
163
+ raise SendToPredictError("Error while using predict's query endpoint")
164
+
165
+ model_max = query_info.max_context
166
+ if max_tokens is not None and max_tokens.context is not None:
167
+ if max_tokens.context > model_max:
168
+ raise InvalidQueryError(
169
+ "max_tokens.context",
170
+ f"Max context tokens is higher than the model's limit of {model_max}",
171
+ )
172
+ return max_tokens.context
173
+ return model_max
174
+
175
+ def get_max_answer_tokens(self, max_tokens: MaxTokens | None) -> int | None:
176
+ if max_tokens is not None and max_tokens.answer is not None:
177
+ return max_tokens.answer
178
+ return None
179
+
180
+
181
async def query_information(
    kbid: str,
    query: str,
    semantic_model: str | None,
    generative_model: str | None = None,
    rephrase: bool = False,
    rephrase_prompt: str | None = None,
    query_image: Image | None = None,
) -> QueryInfo:
    """Call the Predict API /query endpoint and return its response.

    Builds a `QueryModel` from the given parameters and sends it through the
    predict utility for the given knowledge box.
    """
    # NOTE: When moving /ask to RAO, this will need to change to whatever client/utility is used
    # to call NUA predict (internally or externally in the case of onprem).
    semantic_models: list[str] | None = None
    if semantic_model:
        semantic_models = [semantic_model]
    request = QueryModel(
        text=query,
        semantic_models=semantic_models,
        generative_model=generative_model,
        rephrase=rephrase,
        rephrase_prompt=rephrase_prompt,
        query_image=query_image,
    )
    return await get_predict().query(kbid, request)
@@ -19,7 +19,6 @@
19
19
 
20
20
  import base64
21
21
  from io import BytesIO
22
- from typing import Optional
23
22
 
24
23
  from nucliadb.common.ids import ParagraphId
25
24
  from nucliadb.ingest.fields.file import File
@@ -29,7 +28,8 @@ from nucliadb_utils.storages.storage import Storage
29
28
  from nucliadb_utils.utilities import get_storage
30
29
 
31
30
 
32
- async def get_page_image(kbid: str, paragraph_id: ParagraphId, page_number: int) -> Optional[Image]:
31
+ # DEPRECATED(decoupled-ask): remove once old_prompt.py is removed
32
+ async def get_page_image(kbid: str, paragraph_id: ParagraphId, page_number: int) -> Image | None:
33
33
  storage = await get_storage(service_name=SERVICE_NAME)
34
34
  sf = storage.file_extracted(
35
35
  kbid=kbid,
@@ -48,7 +48,8 @@ async def get_page_image(kbid: str, paragraph_id: ParagraphId, page_number: int)
48
48
  return image
49
49
 
50
50
 
51
- async def get_paragraph_image(kbid: str, paragraph_id: ParagraphId, reference: str) -> Optional[Image]:
51
+ # DEPRECATED(decoupled-ask): remove once old_prompt.py is removed
52
+ async def get_paragraph_image(kbid: str, paragraph_id: ParagraphId, reference: str) -> Image | None:
52
53
  storage = await get_storage(service_name=SERVICE_NAME)
53
54
  sf = storage.file_extracted(
54
55
  kbid=kbid,
@@ -67,7 +68,8 @@ async def get_paragraph_image(kbid: str, paragraph_id: ParagraphId, reference: s
67
68
  return image
68
69
 
69
70
 
70
- async def get_file_thumbnail_image(file: File) -> Optional[Image]:
71
+ # DEPRECATED(decoupled-ask): remove once old_prompt.py is removed
72
+ async def get_file_thumbnail_image(file: File) -> Image | None:
71
73
  fed = await file.get_file_extracted_data()
72
74
  if fed is None or not fed.HasField("file_thumbnail"):
73
75
  return None