nucliadb 6.9.1.post5192__py3-none-any.whl → 6.10.0.post5705__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (231)
  1. migrations/0023_backfill_pg_catalog.py +2 -2
  2. migrations/0029_backfill_field_status.py +3 -4
  3. migrations/0032_remove_old_relations.py +2 -3
  4. migrations/0038_backfill_catalog_field_labels.py +2 -2
  5. migrations/0039_backfill_converation_splits_metadata.py +2 -2
  6. migrations/0041_reindex_conversations.py +137 -0
  7. migrations/pg/0010_shards_index.py +34 -0
  8. nucliadb/search/api/v1/resource/utils.py → migrations/pg/0011_catalog_statistics.py +5 -6
  9. migrations/pg/0012_catalog_statistics_undo.py +26 -0
  10. nucliadb/backups/create.py +2 -15
  11. nucliadb/backups/restore.py +4 -15
  12. nucliadb/backups/tasks.py +4 -1
  13. nucliadb/common/back_pressure/cache.py +2 -3
  14. nucliadb/common/back_pressure/materializer.py +7 -13
  15. nucliadb/common/back_pressure/settings.py +6 -6
  16. nucliadb/common/back_pressure/utils.py +1 -0
  17. nucliadb/common/cache.py +9 -9
  18. nucliadb/common/catalog/interface.py +12 -12
  19. nucliadb/common/catalog/pg.py +41 -29
  20. nucliadb/common/catalog/utils.py +3 -3
  21. nucliadb/common/cluster/manager.py +5 -4
  22. nucliadb/common/cluster/rebalance.py +483 -114
  23. nucliadb/common/cluster/rollover.py +25 -9
  24. nucliadb/common/cluster/settings.py +3 -8
  25. nucliadb/common/cluster/utils.py +34 -8
  26. nucliadb/common/context/__init__.py +7 -8
  27. nucliadb/common/context/fastapi.py +1 -2
  28. nucliadb/common/datamanagers/__init__.py +2 -4
  29. nucliadb/common/datamanagers/atomic.py +4 -2
  30. nucliadb/common/datamanagers/cluster.py +1 -2
  31. nucliadb/common/datamanagers/fields.py +3 -4
  32. nucliadb/common/datamanagers/kb.py +6 -6
  33. nucliadb/common/datamanagers/labels.py +2 -3
  34. nucliadb/common/datamanagers/resources.py +10 -33
  35. nucliadb/common/datamanagers/rollover.py +5 -7
  36. nucliadb/common/datamanagers/search_configurations.py +1 -2
  37. nucliadb/common/datamanagers/synonyms.py +1 -2
  38. nucliadb/common/datamanagers/utils.py +4 -4
  39. nucliadb/common/datamanagers/vectorsets.py +4 -4
  40. nucliadb/common/external_index_providers/base.py +32 -5
  41. nucliadb/common/external_index_providers/manager.py +4 -5
  42. nucliadb/common/filter_expression.py +128 -40
  43. nucliadb/common/http_clients/processing.py +12 -23
  44. nucliadb/common/ids.py +6 -4
  45. nucliadb/common/locking.py +1 -2
  46. nucliadb/common/maindb/driver.py +9 -8
  47. nucliadb/common/maindb/local.py +5 -5
  48. nucliadb/common/maindb/pg.py +9 -8
  49. nucliadb/common/nidx.py +3 -4
  50. nucliadb/export_import/datamanager.py +4 -3
  51. nucliadb/export_import/exporter.py +11 -19
  52. nucliadb/export_import/importer.py +13 -6
  53. nucliadb/export_import/tasks.py +2 -0
  54. nucliadb/export_import/utils.py +6 -18
  55. nucliadb/health.py +2 -2
  56. nucliadb/ingest/app.py +8 -8
  57. nucliadb/ingest/consumer/consumer.py +8 -10
  58. nucliadb/ingest/consumer/pull.py +3 -8
  59. nucliadb/ingest/consumer/service.py +3 -3
  60. nucliadb/ingest/consumer/utils.py +1 -1
  61. nucliadb/ingest/fields/base.py +28 -49
  62. nucliadb/ingest/fields/conversation.py +12 -12
  63. nucliadb/ingest/fields/exceptions.py +1 -2
  64. nucliadb/ingest/fields/file.py +22 -8
  65. nucliadb/ingest/fields/link.py +7 -7
  66. nucliadb/ingest/fields/text.py +2 -3
  67. nucliadb/ingest/orm/brain_v2.py +78 -64
  68. nucliadb/ingest/orm/broker_message.py +2 -4
  69. nucliadb/ingest/orm/entities.py +10 -209
  70. nucliadb/ingest/orm/index_message.py +4 -4
  71. nucliadb/ingest/orm/knowledgebox.py +18 -27
  72. nucliadb/ingest/orm/processor/auditing.py +1 -3
  73. nucliadb/ingest/orm/processor/data_augmentation.py +1 -2
  74. nucliadb/ingest/orm/processor/processor.py +27 -27
  75. nucliadb/ingest/orm/processor/sequence_manager.py +1 -2
  76. nucliadb/ingest/orm/resource.py +72 -70
  77. nucliadb/ingest/orm/utils.py +1 -1
  78. nucliadb/ingest/processing.py +17 -17
  79. nucliadb/ingest/serialize.py +202 -145
  80. nucliadb/ingest/service/writer.py +3 -109
  81. nucliadb/ingest/settings.py +3 -4
  82. nucliadb/ingest/utils.py +1 -2
  83. nucliadb/learning_proxy.py +11 -11
  84. nucliadb/metrics_exporter.py +5 -4
  85. nucliadb/middleware/__init__.py +82 -1
  86. nucliadb/migrator/datamanager.py +3 -4
  87. nucliadb/migrator/migrator.py +1 -2
  88. nucliadb/migrator/models.py +1 -2
  89. nucliadb/migrator/settings.py +1 -2
  90. nucliadb/models/internal/augment.py +614 -0
  91. nucliadb/models/internal/processing.py +19 -19
  92. nucliadb/openapi.py +2 -2
  93. nucliadb/purge/__init__.py +3 -8
  94. nucliadb/purge/orphan_shards.py +1 -2
  95. nucliadb/reader/__init__.py +5 -0
  96. nucliadb/reader/api/models.py +6 -13
  97. nucliadb/reader/api/v1/download.py +59 -38
  98. nucliadb/reader/api/v1/export_import.py +4 -4
  99. nucliadb/reader/api/v1/learning_config.py +24 -4
  100. nucliadb/reader/api/v1/resource.py +61 -9
  101. nucliadb/reader/api/v1/services.py +18 -14
  102. nucliadb/reader/app.py +3 -1
  103. nucliadb/reader/reader/notifications.py +1 -2
  104. nucliadb/search/api/v1/__init__.py +2 -0
  105. nucliadb/search/api/v1/ask.py +3 -4
  106. nucliadb/search/api/v1/augment.py +585 -0
  107. nucliadb/search/api/v1/catalog.py +11 -15
  108. nucliadb/search/api/v1/find.py +16 -22
  109. nucliadb/search/api/v1/hydrate.py +25 -25
  110. nucliadb/search/api/v1/knowledgebox.py +1 -2
  111. nucliadb/search/api/v1/predict_proxy.py +1 -2
  112. nucliadb/search/api/v1/resource/ask.py +7 -7
  113. nucliadb/search/api/v1/resource/ingestion_agents.py +5 -6
  114. nucliadb/search/api/v1/resource/search.py +9 -11
  115. nucliadb/search/api/v1/retrieve.py +130 -0
  116. nucliadb/search/api/v1/search.py +28 -32
  117. nucliadb/search/api/v1/suggest.py +11 -14
  118. nucliadb/search/api/v1/summarize.py +1 -2
  119. nucliadb/search/api/v1/utils.py +2 -2
  120. nucliadb/search/app.py +3 -2
  121. nucliadb/search/augmentor/__init__.py +21 -0
  122. nucliadb/search/augmentor/augmentor.py +232 -0
  123. nucliadb/search/augmentor/fields.py +704 -0
  124. nucliadb/search/augmentor/metrics.py +24 -0
  125. nucliadb/search/augmentor/paragraphs.py +334 -0
  126. nucliadb/search/augmentor/resources.py +238 -0
  127. nucliadb/search/augmentor/utils.py +33 -0
  128. nucliadb/search/lifecycle.py +3 -1
  129. nucliadb/search/predict.py +24 -17
  130. nucliadb/search/predict_models.py +8 -9
  131. nucliadb/search/requesters/utils.py +11 -10
  132. nucliadb/search/search/cache.py +19 -23
  133. nucliadb/search/search/chat/ask.py +88 -59
  134. nucliadb/search/search/chat/exceptions.py +3 -5
  135. nucliadb/search/search/chat/fetcher.py +201 -0
  136. nucliadb/search/search/chat/images.py +6 -4
  137. nucliadb/search/search/chat/old_prompt.py +1375 -0
  138. nucliadb/search/search/chat/parser.py +510 -0
  139. nucliadb/search/search/chat/prompt.py +563 -615
  140. nucliadb/search/search/chat/query.py +449 -36
  141. nucliadb/search/search/chat/rpc.py +85 -0
  142. nucliadb/search/search/fetch.py +3 -4
  143. nucliadb/search/search/filters.py +8 -11
  144. nucliadb/search/search/find.py +33 -31
  145. nucliadb/search/search/find_merge.py +124 -331
  146. nucliadb/search/search/graph_strategy.py +14 -12
  147. nucliadb/search/search/hydrator/__init__.py +3 -152
  148. nucliadb/search/search/hydrator/fields.py +92 -50
  149. nucliadb/search/search/hydrator/images.py +7 -7
  150. nucliadb/search/search/hydrator/paragraphs.py +42 -26
  151. nucliadb/search/search/hydrator/resources.py +20 -16
  152. nucliadb/search/search/ingestion_agents.py +5 -5
  153. nucliadb/search/search/merge.py +90 -94
  154. nucliadb/search/search/metrics.py +10 -9
  155. nucliadb/search/search/paragraphs.py +7 -9
  156. nucliadb/search/search/predict_proxy.py +13 -9
  157. nucliadb/search/search/query.py +14 -86
  158. nucliadb/search/search/query_parser/fetcher.py +51 -82
  159. nucliadb/search/search/query_parser/models.py +19 -20
  160. nucliadb/search/search/query_parser/old_filters.py +20 -19
  161. nucliadb/search/search/query_parser/parsers/ask.py +4 -5
  162. nucliadb/search/search/query_parser/parsers/catalog.py +5 -6
  163. nucliadb/search/search/query_parser/parsers/common.py +5 -6
  164. nucliadb/search/search/query_parser/parsers/find.py +6 -26
  165. nucliadb/search/search/query_parser/parsers/graph.py +13 -23
  166. nucliadb/search/search/query_parser/parsers/retrieve.py +207 -0
  167. nucliadb/search/search/query_parser/parsers/search.py +15 -53
  168. nucliadb/search/search/query_parser/parsers/unit_retrieval.py +8 -29
  169. nucliadb/search/search/rank_fusion.py +18 -13
  170. nucliadb/search/search/rerankers.py +5 -6
  171. nucliadb/search/search/retrieval.py +300 -0
  172. nucliadb/search/search/summarize.py +5 -6
  173. nucliadb/search/search/utils.py +3 -4
  174. nucliadb/search/settings.py +1 -2
  175. nucliadb/standalone/api_router.py +1 -1
  176. nucliadb/standalone/app.py +4 -3
  177. nucliadb/standalone/auth.py +5 -6
  178. nucliadb/standalone/lifecycle.py +2 -2
  179. nucliadb/standalone/run.py +2 -4
  180. nucliadb/standalone/settings.py +5 -6
  181. nucliadb/standalone/versions.py +3 -4
  182. nucliadb/tasks/consumer.py +13 -8
  183. nucliadb/tasks/models.py +2 -1
  184. nucliadb/tasks/producer.py +3 -3
  185. nucliadb/tasks/retries.py +8 -7
  186. nucliadb/train/api/utils.py +1 -3
  187. nucliadb/train/api/v1/shards.py +1 -2
  188. nucliadb/train/api/v1/trainset.py +1 -2
  189. nucliadb/train/app.py +1 -1
  190. nucliadb/train/generator.py +4 -4
  191. nucliadb/train/generators/field_classifier.py +2 -2
  192. nucliadb/train/generators/field_streaming.py +6 -6
  193. nucliadb/train/generators/image_classifier.py +2 -2
  194. nucliadb/train/generators/paragraph_classifier.py +2 -2
  195. nucliadb/train/generators/paragraph_streaming.py +2 -2
  196. nucliadb/train/generators/question_answer_streaming.py +2 -2
  197. nucliadb/train/generators/sentence_classifier.py +2 -2
  198. nucliadb/train/generators/token_classifier.py +3 -2
  199. nucliadb/train/generators/utils.py +6 -5
  200. nucliadb/train/nodes.py +3 -3
  201. nucliadb/train/resource.py +6 -8
  202. nucliadb/train/settings.py +3 -4
  203. nucliadb/train/types.py +11 -11
  204. nucliadb/train/upload.py +3 -2
  205. nucliadb/train/uploader.py +1 -2
  206. nucliadb/train/utils.py +1 -2
  207. nucliadb/writer/api/v1/export_import.py +4 -1
  208. nucliadb/writer/api/v1/field.py +7 -11
  209. nucliadb/writer/api/v1/knowledgebox.py +3 -4
  210. nucliadb/writer/api/v1/resource.py +9 -20
  211. nucliadb/writer/api/v1/services.py +10 -132
  212. nucliadb/writer/api/v1/upload.py +73 -72
  213. nucliadb/writer/app.py +8 -2
  214. nucliadb/writer/resource/basic.py +12 -15
  215. nucliadb/writer/resource/field.py +7 -5
  216. nucliadb/writer/resource/origin.py +7 -0
  217. nucliadb/writer/settings.py +2 -3
  218. nucliadb/writer/tus/__init__.py +2 -3
  219. nucliadb/writer/tus/azure.py +1 -3
  220. nucliadb/writer/tus/dm.py +3 -3
  221. nucliadb/writer/tus/exceptions.py +3 -4
  222. nucliadb/writer/tus/gcs.py +5 -6
  223. nucliadb/writer/tus/s3.py +2 -3
  224. nucliadb/writer/tus/storage.py +3 -3
  225. {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/METADATA +9 -10
  226. nucliadb-6.10.0.post5705.dist-info/RECORD +410 -0
  227. nucliadb/common/datamanagers/entities.py +0 -139
  228. nucliadb-6.9.1.post5192.dist-info/RECORD +0 -392
  229. {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/WHEEL +0 -0
  230. {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/entry_points.txt +0 -0
  231. {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/top_level.txt +0 -0
nucliadb/search/search/query_parser/parsers/unit_retrieval.py

@@ -17,7 +17,6 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
-from typing import Optional

 from nidx_protos import nodereader_pb2
 from nidx_protos.nodereader_pb2 import SearchRequest
@@ -25,34 +24,14 @@ from nidx_protos.nodereader_pb2 import SearchRequest
 from nucliadb.common.filter_expression import add_and_expression
 from nucliadb.search.search.filters import translate_label
 from nucliadb.search.search.metrics import node_features, query_parser_observer
-from nucliadb.search.search.query import apply_entities_filter, get_sort_field_proto
+from nucliadb.search.search.query import get_sort_field_proto
 from nucliadb.search.search.query_parser.models import ParsedQuery, PredictReranker, UnitRetrieval
 from nucliadb.search.search.query_parser.parsers.graph import parse_path_query
-from nucliadb_models.labels import LABEL_HIDDEN, translate_system_to_alias_label
+from nucliadb_models.labels import LABEL_HIDDEN
 from nucliadb_models.search import SortOrderMap
 from nucliadb_protos import utils_pb2


-@query_parser_observer.wrap({"type": "convert_retrieval_to_proto"})
-async def legacy_convert_retrieval_to_proto(
-    parsed: ParsedQuery,
-) -> tuple[SearchRequest, bool, list[str], Optional[str]]:
-    converter = _Converter(parsed.retrieval)
-    request = converter.into_search_request()
-
-    # XXX: legacy values that were returned by QueryParser but not always
-    # needed. We should find a better abstraction
-
-    incomplete = is_incomplete(parsed.retrieval)
-    autofilter = converter._autofilter
-
-    rephrased_query = None
-    if parsed.retrieval.query.semantic:
-        rephrased_query = await parsed.fetcher.get_rephrased_query()
-
-    return request, incomplete, autofilter, rephrased_query
-
-
 @query_parser_observer.wrap({"type": "convert_retrieval_to_proto"})
 def convert_retrieval_to_proto(retrieval: UnitRetrieval) -> SearchRequest:
     converter = _Converter(retrieval)
@@ -65,8 +44,6 @@ class _Converter:
         self.req = nodereader_pb2.SearchRequest()
         self.retrieval = retrieval

-        self._autofilter: list[str] = []
-
     def into_search_request(self) -> nodereader_pb2.SearchRequest:
         """Generate a SearchRequest proto from a retrieval operation."""
         self._apply_text_queries()
@@ -75,6 +52,7 @@ class _Converter:
         self._apply_graph_query()
         self._apply_filters()
         self._apply_top_k()
+
         return self.req

     def _apply_text_queries(self) -> None:
@@ -235,10 +213,6 @@ class _Converter:
         self.req.paragraph_filter.CopyFrom(self.retrieval.filters.paragraph_expression)
         self.req.filter_operator = self.retrieval.filters.filter_expression_operator

-        if self.retrieval.filters.autofilter:
-            entity_filters = apply_entities_filter(self.req, self.retrieval.filters.autofilter)
-            self._autofilter.extend([translate_system_to_alias_label(e) for e in entity_filters])
-
         if self.retrieval.filters.hidden is not None:
             expr = nodereader_pb2.FilterExpression()
             if self.retrieval.filters.hidden:
@@ -281,3 +255,8 @@ def is_incomplete(retrieval: UnitRetrieval) -> bool:
         return False
     incomplete = retrieval.query.semantic.query is None or len(retrieval.query.semantic.query) == 0
     return incomplete
+
+
+def get_rephrased_query(parsed: ParsedQuery) -> str | None:
+    """Given a parsed query, return the rephrased query used, if any."""
+    return parsed.fetcher.get_cached_rephrased_query()
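
The removed legacy_convert_retrieval_to_proto() bundled the proto conversion, the "incomplete" flag and the rephrased query in one call; after this change, callers compose the remaining helpers themselves. A minimal sketch of such a caller follows (not part of the diff; the build_search_request name is hypothetical, and `parsed` is assumed to be a ParsedQuery produced elsewhere by the query parser):

from nucliadb.search.search.query_parser.parsers.unit_retrieval import (
    convert_retrieval_to_proto,
    get_rephrased_query,
    is_incomplete,
)

def build_search_request(parsed):
    # Build the nidx SearchRequest proto from the parsed retrieval
    request = convert_retrieval_to_proto(parsed.retrieval)
    # Flag retrievals whose semantic query could not be computed
    incomplete = is_incomplete(parsed.retrieval)
    # Cached rephrased query, if the fetcher produced one
    rephrased = get_rephrased_query(parsed)
    return request, incomplete, rephrased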
nucliadb/search/search/rank_fusion.py

@@ -20,11 +20,12 @@
 import logging
 from abc import ABC, abstractmethod
 from enum import Enum, auto
-from typing import Optional, TypeVar
+from typing import TypeVar

 from nucliadb.common.external_index_providers.base import ScoredTextBlock
 from nucliadb.common.ids import ParagraphId
 from nucliadb.search.search.query_parser import models as parser_models
+from nucliadb_models.retrieval import RrfScore, Score, WeightedCombSumScore
 from nucliadb_models.search import SCORE_TYPE
 from nucliadb_telemetry.metrics import Observer

@@ -127,7 +128,7 @@ class ReciprocalRankFusion(RankFusionAlgorithm):
         k: float = 60.0,
         *,
         window: int,
-        weights: Optional[dict[str, float]] = None,
+        weights: dict[str, float] | None = None,
         default_weight: float = 1.0,
     ):
         super().__init__(window)
@@ -145,7 +146,7 @@ class ReciprocalRankFusion(RankFusionAlgorithm):
         sources: dict[str, list[ScoredItem]],
     ) -> list[ScoredItem]:
         # accumulated scores per paragraph
-        scores: dict[ParagraphId, tuple[float, SCORE_TYPE]] = {}
+        scores: dict[ParagraphId, tuple[float, SCORE_TYPE, list[Score]]] = {}
         # pointers from paragraph to the original source
         match_positions: dict[ParagraphId, list[tuple[int, int]]] = {}

@@ -161,11 +162,12 @@ class ReciprocalRankFusion(RankFusionAlgorithm):
         for i, (ranking, weight) in enumerate(rankings):
             for rank, item in enumerate(ranking):
                 id = item.paragraph_id
-                score, score_type = scores.setdefault(id, (0, item.score_type))
+                score, score_type, history = scores.setdefault(id, (0, item.score_type, []))
                 score += 1 / (self._k + rank) * weight
+                history.append(item.current_score)
                 if {score_type, item.score_type} == {SCORE_TYPE.BM25, SCORE_TYPE.VECTOR}:
                     score_type = SCORE_TYPE.BOTH
-                scores[id] = (score, score_type)
+                scores[id] = (score, score_type, history)

                 position = (i, rank)
                 match_positions.setdefault(item.paragraph_id, []).append(position)
@@ -175,9 +177,10 @@ class ReciprocalRankFusion(RankFusionAlgorithm):
             # we are getting only one position, effectively deduplicating
             # multiple matches for the same text block
             i, j = match_positions[paragraph_id][0]
-            score, score_type = scores[paragraph_id]
+            score, score_type, history = scores[paragraph_id]
             item = rankings[i][0][j]
-            item.score = score
+            history.append(RrfScore(score=score))
+            item.scores = history
             item.score_type = score_type
             merged.append(item)

@@ -207,7 +210,7 @@ class WeightedCombSum(RankFusionAlgorithm):
         self,
         *,
         window: int,
-        weights: Optional[dict[str, float]] = None,
+        weights: dict[str, float] | None = None,
         default_weight: float = 1.0,
     ):
         super().__init__(window)
@@ -217,7 +220,7 @@ class WeightedCombSum(RankFusionAlgorithm):
     @rank_fusion_observer.wrap({"type": "weighted_comb_sum"})
     def _fuse(self, sources: dict[str, list[ScoredItem]]) -> list[ScoredItem]:
         # accumulated scores per paragraph
-        scores: dict[ParagraphId, tuple[float, SCORE_TYPE]] = {}
+        scores: dict[ParagraphId, tuple[float, SCORE_TYPE, list[Score]]] = {}
         # pointers from paragraph to the original source
         match_positions: dict[ParagraphId, list[tuple[int, int]]] = {}

@@ -228,11 +231,12 @@ class WeightedCombSum(RankFusionAlgorithm):
         for i, (ranking, weight) in enumerate(rankings):
             for j, item in enumerate(ranking):
                 id = item.paragraph_id
-                score, score_type = scores.setdefault(id, (0, item.score_type))
+                score, score_type, history = scores.setdefault(id, (0, item.score_type, []))
                 score += item.score * weight
+                history.append(item.current_score)
                 if {score_type, item.score_type} == {SCORE_TYPE.BM25, SCORE_TYPE.VECTOR}:
                     score_type = SCORE_TYPE.BOTH
-                scores[id] = (score, score_type)
+                scores[id] = (score, score_type, history)

                 position = (i, j)
                 match_positions.setdefault(item.paragraph_id, []).append(position)
@@ -242,9 +246,10 @@ class WeightedCombSum(RankFusionAlgorithm):
             # we are getting only one position, effectively deduplicating
             # multiple matches for the same text block
             i, j = match_positions[paragraph_id][0]
-            score, score_type = scores[paragraph_id]
+            item = rankings[i][0][j]
-            item.score = score
+            history.append(WeightedCombSumScore(score=score))
+            item.scores = history
             item.score_type = score_type
             merged.append(item)

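The accumulation above is plain reciprocal rank fusion: each source contributes weight / (k + rank) per item, and items matched by several sources add their contributions (the new `scores`/`history` fields only record the per-source scores alongside the fused one). The following illustrative-only sketch reproduces that arithmetic on plain dicts instead of nucliadb's ScoredItem objects, so the effect of k = 60 and the cross-source boost is easy to see:

def rrf(rankings: dict[str, list[str]], k: float = 60.0, weights: dict[str, float] | None = None) -> list[tuple[str, float]]:
    # Each source contributes weight / (k + rank); shared items accumulate.
    weights = weights or {}
    fused: dict[str, float] = {}
    for source, ranking in rankings.items():
        w = weights.get(source, 1.0)
        for rank, item_id in enumerate(ranking):
            fused[item_id] = fused.get(item_id, 0.0) + 1 / (k + rank) * w
    return sorted(fused.items(), key=lambda kv: kv[1], reverse=True)

# "p1" appears in both rankings, so it outranks the single-source matches.
print(rrf({"keyword": ["p1", "p2"], "semantic": ["p3", "p1"]}))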
nucliadb/search/search/rerankers.py

@@ -21,7 +21,8 @@
 import logging
 from abc import ABC, abstractmethod, abstractproperty
 from dataclasses import dataclass
-from typing import Optional
+
+from typing_extensions import assert_never

 from nucliadb.search.predict import ProxiedPredictAPIError, SendToPredictError
 from nucliadb.search.search.query_parser import models as parser_models
@@ -63,7 +64,7 @@ class RerankingOptions:

 class Reranker(ABC):
     @abstractproperty
-    def window(self) -> Optional[int]:
+    def window(self) -> int | None:
         """Number of elements the reranker requests. `None` means no specific
         window is enforced."""
         ...
@@ -102,7 +103,7 @@ class NoopReranker(Reranker):
     """

     @property
-    def window(self) -> Optional[int]:
+    def window(self) -> int | None:
         return None

     @reranker_observer.wrap({"type": "noop"})
@@ -182,9 +183,7 @@ def get_reranker(reranker: parser_models.Reranker) -> Reranker:
         algorithm = PredictReranker(reranker.window)

     else:  # pragma: no cover
-        # This is a trick so mypy generates an error if this branch can be reached,
-        # that is, if we are missing some ifs
-        _a: int = "a"
+        assert_never(reranker)

     return algorithm

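The `assert_never(reranker)` call replaces the old `_a: int = "a"` typing trick with the standard exhaustiveness check: mypy reports an error if any reranker variant is left unhandled, and the call also raises at runtime if one slips through. An illustrative-only sketch of the pattern, using a toy enum rather than nucliadb's reranker models:

from enum import Enum, auto
from typing_extensions import assert_never

class Kind(Enum):
    NOOP = auto()
    PREDICT = auto()

def pick(kind: Kind) -> str:
    if kind is Kind.NOOP:
        return "noop"
    elif kind is Kind.PREDICT:
        return "predict"
    else:
        # mypy errors here if a Kind member is not handled above;
        # at runtime an unhandled value raises an AssertionError.
        assert_never(kind)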
nucliadb/search/search/retrieval.py (new file)

@@ -0,0 +1,300 @@
+# Copyright (C) 2021 Bosutech XXI S.L.
+#
+# nucliadb is offered under the AGPL v3.0 and as commercial software.
+# For commercial licensing, contact us at info@nuclia.com.
+#
+# AGPL:
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+#
+from collections.abc import Iterable
+
+from nidx_protos.nodereader_pb2 import (
+    DocumentScored,
+    GraphSearchResponse,
+    ParagraphResult,
+    ParagraphSearchResponse,
+    SearchRequest,
+    SearchResponse,
+    VectorSearchResponse,
+)
+
+from nucliadb.common.external_index_providers.base import TextBlockMatch
+from nucliadb.common.ids import ParagraphId, VectorId
+from nucliadb.search import logger
+from nucliadb.search.requesters.utils import Method, nidx_query
+from nucliadb.search.search.metrics import merge_observer, search_observer
+from nucliadb.search.search.query_parser.models import UnitRetrieval
+from nucliadb.search.search.query_parser.parsers.unit_retrieval import convert_retrieval_to_proto
+from nucliadb.search.search.rank_fusion import IndexSource, get_rank_fusion
+from nucliadb_models.retrieval import GraphScore, KeywordScore, SemanticScore
+from nucliadb_models.search import SCORE_TYPE, TextPosition
+
+# Constant score given to all graph results until we implement graph scoring
+FAKE_GRAPH_SCORE = 1.0
+
+
+async def nidx_search(kbid: str, pb_query: SearchRequest) -> tuple[SearchResponse, list[str]]:
+    """Wrapper around nidx_query for SEARCH that merges shards results in a
+    single response.
+
+    At some point, nidx will provide this functionality and we'll be able to
+    remove this.
+
+    """
+    shards_responses, queried_shards = await nidx_query(kbid, Method.SEARCH, pb_query)
+    response = merge_shard_responses(shards_responses)
+    return response, queried_shards
+
+
+@search_observer.wrap({"type": "text_block_search"})
+async def text_block_search(
+    kbid: str, retrieval: UnitRetrieval
+) -> tuple[list[TextBlockMatch], SearchRequest, SearchResponse, list[str]]:
+    """Search for text blocks in multiple indexes and return an rank fused view.
+
+    This search method provides a textual view of the data. For example, given a
+    graph query, it will return the text blocks associated with matched
+    triplets, not the triplet itself.
+
+    """
+    assert retrieval.rank_fusion is not None, "text block search requries a rank fusion algorithm"
+
+    pb_query = convert_retrieval_to_proto(retrieval)
+    shards_response, queried_shards = await nidx_search(kbid, pb_query)
+
+    keyword_results = keyword_results_to_text_block_matches(shards_response.paragraph.results)
+    semantic_results = semantic_results_to_text_block_matches(shards_response.vector.documents)
+    graph_results = graph_results_to_text_block_matches(shards_response.graph)
+
+    rank_fusion = get_rank_fusion(retrieval.rank_fusion)
+    merged_text_blocks = rank_fusion.fuse(
+        {
+            IndexSource.KEYWORD: keyword_results,
+            IndexSource.SEMANTIC: semantic_results,
+            IndexSource.GRAPH: graph_results,
+        }
+    )
+
+    # cut to the rank fusion window. As we ask each shard and index this window,
+    # we'll normally have extra results
+    text_blocks = merged_text_blocks[: retrieval.rank_fusion.window]
+
+    return text_blocks, pb_query, shards_response, queried_shards
+
+
+@merge_observer.wrap({"type": "shards_responses"})
+def merge_shard_responses(
+    responses: list[SearchResponse],
+) -> SearchResponse:
+    """Merge search responses into a single response as if there were no shards
+    involved.
+
+    ATENTION! This is not a complete merge, we are only merging the fields
+    needed to compose a /find response.
+
+    """
+    paragraphs = []
+    vectors = []
+    graphs = []
+    for response in responses:
+        paragraphs.append(response.paragraph)
+        vectors.append(response.vector)
+        graphs.append(response.graph)
+
+    merged = SearchResponse(
+        paragraph=merge_shards_keyword_responses(paragraphs),
+        vector=merge_shards_semantic_responses(vectors),
+        graph=merge_shards_graph_responses(graphs),
+    )
+    return merged
+
+
+def merge_shards_keyword_responses(
+    keyword_responses: list[ParagraphSearchResponse],
+) -> ParagraphSearchResponse:
+    """Merge keyword (paragraph) search responses into a single response as if
+    there were no shards involved.
+
+    ATENTION! This is not a complete merge, we are only merging the fields
+    needed to compose a /find response.
+
+    """
+    merged = ParagraphSearchResponse()
+    for response in keyword_responses:
+        merged.query = response.query
+        merged.next_page = merged.next_page or response.next_page
+        merged.total += response.total
+        merged.results.extend(response.results)
+        merged.ematches.extend(response.ematches)
+
+    return merged
+
+
+def merge_shards_semantic_responses(
+    semantic_responses: list[VectorSearchResponse],
+) -> VectorSearchResponse:
+    """Merge semantic (vector) search responses into a single response as if
+    there were no shards involved.
+
+    ATENTION! This is not a complete merge, we are only merging the fields
+    needed to compose a /find response.
+
+    """
+    merged = VectorSearchResponse()
+    for response in semantic_responses:
+        merged.documents.extend(response.documents)
+
+    return merged
+
+
+def merge_shards_graph_responses(
+    graph_responses: list[GraphSearchResponse],
+):
+    merged = GraphSearchResponse()
+
+    for response in graph_responses:
+        nodes_offset = len(merged.nodes)
+        relations_offset = len(merged.relations)
+
+        # paths contain indexes to nodes and relations, we must offset them
+        # while merging responses to maintain valid data
+        for path in response.graph:
+            merged_path = GraphSearchResponse.Path()
+            merged_path.CopyFrom(path)
+            merged_path.source += nodes_offset
+            merged_path.relation += relations_offset
+            merged_path.destination += nodes_offset
+            merged.graph.append(merged_path)
+
+        merged.nodes.extend(response.nodes)
+        merged.relations.extend(response.relations)
+
+    return merged
+
+
+def keyword_result_to_text_block_match(item: ParagraphResult) -> TextBlockMatch:
+    fuzzy_result = len(item.matches) > 0
+    return TextBlockMatch(
+        paragraph_id=ParagraphId.from_string(item.paragraph),
+        scores=[KeywordScore(score=item.score.bm25)],
+        score_type=SCORE_TYPE.BM25,
+        order=0,  # NOTE: this will be filled later
+        text=None,  # NOTE: this will be filled later too
+        position=TextPosition(
+            page_number=item.metadata.position.page_number,
+            index=item.metadata.position.index,
+            start=item.start,
+            end=item.end,
+            start_seconds=[x for x in item.metadata.position.start_seconds],
+            end_seconds=[x for x in item.metadata.position.end_seconds],
+        ),
+        # XXX: we should split labels
+        field_labels=[],
+        paragraph_labels=list(item.labels),
+        fuzzy_search=fuzzy_result,
+        is_a_table=item.metadata.representation.is_a_table,
+        representation_file=item.metadata.representation.file or None,
+        page_with_visual=item.metadata.page_with_visual,
+    )
+
+
+def keyword_results_to_text_block_matches(items: Iterable[ParagraphResult]) -> list[TextBlockMatch]:
+    return [keyword_result_to_text_block_match(item) for item in items]
+
+
+class InvalidDocId(Exception):
+    """Raised while parsing an invalid id coming from semantic search"""
+
+    def __init__(self, invalid_vector_id: str):
+        self.invalid_vector_id = invalid_vector_id
+        super().__init__(f"Invalid vector ID: {invalid_vector_id}")
+
+
+def semantic_result_to_text_block_match(item: DocumentScored) -> TextBlockMatch:
+    try:
+        vector_id = VectorId.from_string(item.doc_id.id)
+    except (IndexError, ValueError):
+        raise InvalidDocId(item.doc_id.id)
+
+    return TextBlockMatch(
+        paragraph_id=ParagraphId.from_vector_id(vector_id),
+        scores=[SemanticScore(score=item.score)],
+        score_type=SCORE_TYPE.VECTOR,
+        order=0,  # NOTE: this will be filled later
+        text=None,  # NOTE: this will be filled later too
+        position=TextPosition(
+            page_number=item.metadata.position.page_number,
+            index=item.metadata.position.index,
+            start=vector_id.vector_start,
+            end=vector_id.vector_end,
+            start_seconds=[x for x in item.metadata.position.start_seconds],
+            end_seconds=[x for x in item.metadata.position.end_seconds],
+        ),
+        # XXX: we should split labels
+        field_labels=[],
+        paragraph_labels=list(item.labels),
+        fuzzy_search=False,  # semantic search doesn't have fuzziness
+        is_a_table=item.metadata.representation.is_a_table,
+        representation_file=item.metadata.representation.file or None,
+        page_with_visual=item.metadata.page_with_visual,
+    )
+
+
+def semantic_results_to_text_block_matches(items: Iterable[DocumentScored]) -> list[TextBlockMatch]:
+    text_blocks: list[TextBlockMatch] = []
+    for item in items:
+        try:
+            text_block = semantic_result_to_text_block_match(item)
+        except InvalidDocId as exc:
+            logger.warning(f"Skipping invalid doc_id: {exc.invalid_vector_id}")
+            continue
+        text_blocks.append(text_block)
+    return text_blocks
+
+
+def graph_results_to_text_block_matches(item: GraphSearchResponse) -> list[TextBlockMatch]:
+    matches = []
+    for path in item.graph:
+        metadata = path.metadata
+
+        if not metadata.paragraph_id:
+            continue
+
+        paragraph_id = ParagraphId.from_string(metadata.paragraph_id)
+        matches.append(
+            TextBlockMatch(
+                paragraph_id=paragraph_id,
+                scores=[GraphScore(score=FAKE_GRAPH_SCORE)],
+                score_type=SCORE_TYPE.RELATION_RELEVANCE,
+                order=0,  # NOTE: this will be filled later
+                text=None,  # NOTE: this will be filled later too
+                position=TextPosition(
+                    page_number=0,
+                    index=0,
+                    start=paragraph_id.paragraph_start,
+                    end=paragraph_id.paragraph_end,
+                    start_seconds=[],
+                    end_seconds=[],
+                ),
+                # XXX: we should split labels
+                field_labels=[],
+                paragraph_labels=[],
+                fuzzy_search=False,  # TODO: this depends on the query, should we populate it?
+                is_a_table=False,
+                representation_file="",
+                page_with_visual=False,
+            )
+        )
+
+    return matches
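
A detail worth calling out in merge_shards_graph_responses above: graph paths reference nodes and relations by index, so when two shard responses are concatenated the second response's indices must be shifted by the sizes already accumulated. An illustrative-only sketch of that offset trick, using plain lists and tuples instead of the GraphSearchResponse protos:

def merge_graphs(responses: list[dict]) -> dict:
    # Paths are (source, relation, destination) index triples into nodes/relations.
    merged = {"nodes": [], "relations": [], "paths": []}
    for response in responses:
        nodes_offset = len(merged["nodes"])
        relations_offset = len(merged["relations"])
        for source, relation, destination in response["paths"]:
            merged["paths"].append(
                (source + nodes_offset, relation + relations_offset, destination + nodes_offset)
            )
        merged["nodes"].extend(response["nodes"])
        merged["relations"].extend(response["relations"])
    return merged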
nucliadb/search/search/summarize.py

@@ -18,7 +18,6 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 import asyncio
-from typing import Optional

 from nucliadb.common import datamanagers
 from nucliadb.common.maindb.utils import get_driver
@@ -36,7 +35,7 @@ from nucliadb_models.search import (
 from nucliadb_protos.utils_pb2 import ExtractedText
 from nucliadb_utils.utilities import get_storage

-ExtractedTexts = list[tuple[str, str, Optional[ExtractedText]]]
+ExtractedTexts = list[tuple[str, str, ExtractedText | None]]

 MAX_GET_EXTRACTED_TEXT_OPS = 20

@@ -46,7 +45,7 @@ class NoResourcesToSummarize(Exception):


 async def summarize(
-    kbid: str, request: SummarizeRequest, extra_predict_headers: Optional[dict[str, str]]
+    kbid: str, request: SummarizeRequest, extra_predict_headers: dict[str, str] | None
 ) -> SummarizedResponse:
     predict_request = SummarizeModel()
     predict_request.generative_model = request.generative_model
@@ -87,7 +86,7 @@ async def get_extracted_texts(kbid: str, resource_uuids_or_slugs: list[str]) ->
         if uuid is None:
             logger.warning(f"Resource {uuid_or_slug} not found in KB", extra={"kbid": kbid})
             continue
-        resource_orm = Resource(txn=txn, storage=storage, kb=kb_orm, uuid=uuid)
+        resource_orm = Resource(txn=txn, storage=storage, kbid=kbid, uuid=uuid)
         fields = await resource_orm.get_fields(force=True)
         for _, field in fields.items():
             task = asyncio.create_task(get_extracted_text(uuid_or_slug, field, max_tasks))
@@ -115,14 +114,14 @@ async def get_extracted_texts(kbid: str, resource_uuids_or_slugs: list[str]) ->

 async def get_extracted_text(
     uuid_or_slug, field: Field, max_operations: asyncio.Semaphore
-) -> tuple[str, str, Optional[ExtractedText]]:
+) -> tuple[str, str, ExtractedText | None]:
     async with max_operations:
         extracted_text = await field.get_extracted_text(force=True)
         field_key = f"{field.type}/{field.id}"
         return uuid_or_slug, field_key, extracted_text


-async def get_resource_uuid(kbobj: KnowledgeBox, uuid_or_slug: str) -> Optional[str]:
+async def get_resource_uuid(kbobj: KnowledgeBox, uuid_or_slug: str) -> str | None:
     """
     Return the uuid of the resource with the given uuid_or_slug.
     """
nucliadb/search/search/utils.py

@@ -18,7 +18,6 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 import logging
-from typing import Optional

 from pydantic import BaseModel

@@ -30,7 +29,7 @@ from nucliadb_utils.utilities import has_feature
 logger = logging.getLogger(__name__)


-async def filter_hidden_resources(kbid: str, show_hidden: bool) -> Optional[bool]:
+async def filter_hidden_resources(kbid: str, show_hidden: bool) -> bool | None:
     kb_config = await kb.get_config(kbid=kbid)
     hidden_enabled = kb_config and kb_config.hidden_resources_enabled
     if hidden_enabled and not show_hidden:
@@ -41,8 +40,8 @@ async def filter_hidden_resources(kbid: str, show_hidden: bool) -> Optional[bool

 def min_score_from_query_params(
     min_score_bm25: float,
-    min_score_semantic: Optional[float],
-    deprecated_min_score: Optional[float],
+    min_score_semantic: float | None,
+    deprecated_min_score: float | None,
 ) -> MinScore:
     # Keep backward compatibility with the deprecated min_score parameter
     semantic = deprecated_min_score if min_score_semantic is None else min_score_semantic
nucliadb/search/settings.py

@@ -18,7 +18,6 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #

-from typing import Optional

 from pydantic import Field

@@ -43,7 +42,7 @@ class Settings(DriverSettings):
         title="Prequeries max parallel",
         description="The maximum number of prequeries to run in parallel per /ask request",
     )
-    nidx_address: Optional[str] = Field(default=None)
+    nidx_address: str | None = Field(default=None)


 settings = Settings()
nucliadb/standalone/api_router.py

@@ -57,7 +57,7 @@ async def api_config_check(request: Request):
            valid_nua_key = True
        except Exception as exc:
            logger.warning(f"Error validating nua key", exc_info=exc)
-            nua_key_check_error = f"Error checking NUA key: {str(exc)}"
+            nua_key_check_error = f"Error checking NUA key: {exc!s}"
    return JSONResponse(
        {
            "nua_api_key": {
nucliadb/standalone/app.py

@@ -31,7 +31,7 @@ from starlette.responses import HTMLResponse
 from starlette.routing import Mount

 import nucliadb_admin_assets  # type: ignore
-from nucliadb.middleware import ProcessTimeHeaderMiddleware
+from nucliadb.middleware import ClientErrorPayloadLoggerMiddleware, ProcessTimeHeaderMiddleware
 from nucliadb.reader import API_PREFIX
 from nucliadb.reader.api.v1.router import api as api_reader_v1
 from nucliadb.search.api.v1.router import api as api_search_v1
@@ -79,7 +79,7 @@ HOMEPAGE_HTML = """
 </ul>
 </body>
 </html>
-"""  # noqa: E501
+"""


 def application_factory(settings: Settings) -> FastAPI:
@@ -95,13 +95,13 @@ def application_factory(settings: Settings) -> FastAPI:
             backend=get_auth_backend(settings),
         ),
         Middleware(AuditMiddleware, audit_utility_getter=get_audit),
+        Middleware(ClientErrorPayloadLoggerMiddleware),
     ]
     if running_settings.debug:
         middleware.append(Middleware(ProcessTimeHeaderMiddleware))

     fastapi_settings = dict(
         debug=running_settings.debug,
-        middleware=middleware,
         lifespan=lifespan,
         exception_handlers={
             Exception: global_exception_handler,
@@ -122,6 +122,7 @@ def application_factory(settings: Settings) -> FastAPI:
         prefix_format=f"/{API_PREFIX}/v{{major}}",
         default_version=(1, 0),
         enable_latest=False,
+        middleware=middleware,
         kwargs=fastapi_settings,
     )