nucliadb 6.4.0.post4200__py3-none-any.whl → 6.4.0.post4210__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,20 @@
+ # Copyright (C) 2021 Bosutech XXI S.L.
+ #
+ # nucliadb is offered under the AGPL v3.0 and as commercial software.
+ # For commercial licensing, contact us at info@nuclia.com.
+ #
+ # AGPL:
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU Affero General Public License as
+ # published by the Free Software Foundation, either version 3 of the
+ # License, or (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU Affero General Public License for more details.
+ #
+ # You should have received a copy of the GNU Affero General Public License
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
+ #
+ from .materializer import maybe_back_pressure, start_materializer, stop_materializer  # noqa
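This hunk adds the `__init__.py` of a new back-pressure package (judging by the relative import, `nucliadb.common.back_pressure`), re-exporting the materializer's public API. A minimal sketch of how a write endpoint might use the re-exported helper; the FastAPI app, route, and handler below are illustrative, not part of this diff:

```python
# Illustrative sketch, not from this diff: a write endpoint guarding itself
# with the re-exported helper before accepting work.
from fastapi import FastAPI

from nucliadb.common.back_pressure import maybe_back_pressure

app = FastAPI()


@app.post("/kb/{kbid}/resources")  # hypothetical route
async def create_resource(kbid: str):
    # Raises fastapi.HTTPException(status_code=429) when queues are too far behind
    await maybe_back_pressure(kbid)
    ...  # proceed with the write
```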
@@ -0,0 +1,86 @@
+ # Copyright (C) 2021 Bosutech XXI S.L.
+ #
+ # nucliadb is offered under the AGPL v3.0 and as commercial software.
+ # For commercial licensing, contact us at info@nuclia.com.
+ #
+ # AGPL:
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU Affero General Public License as
+ # published by the Free Software Foundation, either version 3 of the
+ # License, or (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU Affero General Public License for more details.
+ #
+ # You should have received a copy of the GNU Affero General Public License
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
+ #
+ import contextlib
+ import logging
+ import threading
+ from datetime import datetime, timezone
+ from typing import Optional
+
+ from cachetools import TTLCache
+
+ from nucliadb.common.back_pressure.utils import BackPressureData, BackPressureException
+ from nucliadb_telemetry import metrics
+
+ logger = logging.getLogger(__name__)
+
+
+ RATE_LIMITED_REQUESTS_COUNTER = metrics.Counter(
+     "nucliadb_rate_limited_requests", labels={"type": "", "cached": ""}
+ )
+
+
+ class BackPressureCache:
+     """
+     Global cache for storing already computed try-again times.
+     It allows us to avoid making the same calculations multiple
+     times if back pressure has been applied.
+     """
+
+     def __init__(self):
+         self._cache = TTLCache(maxsize=1024, ttl=5 * 60)
+         self._lock = threading.Lock()
+
+     def get(self, key: str) -> Optional[BackPressureData]:
+         with self._lock:
+             data = self._cache.get(key, None)
+             if data is None:
+                 return None
+             if datetime.now(timezone.utc) >= data.try_after:
+                 # The key has expired, so remove it from the cache
+                 self._cache.pop(key, None)
+                 return None
+             return data
+
+     def set(self, key: str, data: BackPressureData):
+         with self._lock:
+             self._cache[key] = data
+
+
+ _cache = BackPressureCache()
+
+
+ @contextlib.contextmanager
+ def cached_back_pressure(cache_key: str):
+     """
+     Context manager that caches the computed try-again time so that we
+     don't recompute it if back pressure has already been applied.
+     """
+     data: Optional[BackPressureData] = _cache.get(cache_key)
+     if data is not None:
+         back_pressure_type = data.type
+         RATE_LIMITED_REQUESTS_COUNTER.inc({"type": back_pressure_type, "cached": "true"})
+         raise BackPressureException(data)
+     try:
+         yield
+     except BackPressureException as exc:
+         back_pressure_type = exc.data.type
+         RATE_LIMITED_REQUESTS_COUNTER.inc({"type": back_pressure_type, "cached": "false"})
+         _cache.set(cache_key, exc.data)
+         raise exc
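This file (imported elsewhere in the diff as `nucliadb.common.back_pressure.cache`) memoizes computed try-again times. A usage sketch of the `cached_back_pressure` contract, built only from names introduced in this diff: the first raise inside the block is recorded and counted with `cached="false"`; until `try_after` passes, later entries raise straight from the cache with `cached="true"`:

```python
# Usage sketch of the cached_back_pressure contract defined above.
from datetime import datetime, timedelta, timezone

from nucliadb.common.back_pressure.cache import cached_back_pressure
from nucliadb.common.back_pressure.utils import BackPressureData, BackPressureException

data = BackPressureData(
    type="indexing",
    try_after=datetime.now(timezone.utc) + timedelta(seconds=30),
)

try:
    with cached_back_pressure("kb1-None"):
        raise BackPressureException(data)  # recorded, counted as cached="false"
except BackPressureException:
    pass

try:
    with cached_back_pressure("kb1-None"):
        pass  # never reached: the cached entry re-raises before the body runs
except BackPressureException as exc:
    print(exc)  # "Back pressure applied for indexing. Try again after ..."
```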
@@ -0,0 +1,315 @@
+ # Copyright (C) 2021 Bosutech XXI S.L.
+ #
+ # nucliadb is offered under the AGPL v3.0 and as commercial software.
+ # For commercial licensing, contact us at info@nuclia.com.
+ #
+ # AGPL:
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU Affero General Public License as
+ # published by the Free Software Foundation, either version 3 of the
+ # License, or (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU Affero General Public License for more details.
+ #
+ # You should have received a copy of the GNU Affero General Public License
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
+ #
+ import asyncio
+ import logging
+ import threading
+ from typing import Optional
+
+ from cachetools import TTLCache
+ from fastapi import HTTPException
+
+ from nucliadb.common.back_pressure.cache import cached_back_pressure
+ from nucliadb.common.back_pressure.settings import settings
+ from nucliadb.common.back_pressure.utils import (
+     BackPressureData,
+     BackPressureException,
+     estimate_try_after,
+     get_nats_consumer_pending_messages,
+     is_back_pressure_enabled,
+ )
+ from nucliadb.common.context import ApplicationContext
+ from nucliadb.common.http_clients.processing import ProcessingHTTPClient
+ from nucliadb_telemetry import metrics
+ from nucliadb_utils import const
+ from nucliadb_utils.nats import NatsConnectionManager
+ from nucliadb_utils.settings import is_onprem_nucliadb
+
+ logger = logging.getLogger(__name__)
+
+
+ back_pressure_observer = metrics.Observer("nucliadb_back_pressure", labels={"type": ""})
+
+
+ class BackPressureMaterializer:
+     """
+     Singleton class that will run in the background gathering the different
+     stats to apply back pressure and materializing them in memory. This allows us
+     to do stale reads when checking if back pressure is needed for a particular
+     request - thus not slowing it down.
+     """
+
+     def __init__(
+         self,
+         nats_manager: NatsConnectionManager,
+         indexing_check_interval: int = 30,
+         ingest_check_interval: int = 30,
+     ):
+         self.nats_manager = nats_manager
+         self.processing_http_client = ProcessingHTTPClient()
+
+         self.indexing_check_interval = indexing_check_interval
+         self.ingest_check_interval = ingest_check_interval
+
+         self.ingest_pending: int = 0
+         self.indexing_pending: int = 0
+
+         self._tasks: list[asyncio.Task] = []
+         self._running = False
+
+         self.processing_pending_cache = TTLCache(maxsize=1024, ttl=60)  # type: ignore
+         self.processing_pending_locks: dict[str, asyncio.Lock] = {}
+
+     async def start(self):
+         self._tasks.append(asyncio.create_task(self._get_indexing_pending_task()))
+         self._tasks.append(asyncio.create_task(self._get_ingest_pending_task()))
+         self._running = True
+
+     async def stop(self):
+         for task in self._tasks:
+             task.cancel()
+         self._tasks.clear()
+         await self.processing_http_client.close()
+         self._running = False
+
+     @property
+     def running(self) -> bool:
+         return self._running
+
+     async def get_processing_pending(self, kbid: str) -> int:
+         """
+         We don't materialize the pending messages for every kbid, but values are cached for some time.
+         """
+         cached = self.processing_pending_cache.get(kbid)
+         if cached is not None:
+             return cached
+
+         lock = self.processing_pending_locks.setdefault(kbid, asyncio.Lock())
+         async with lock:
+             # Check again if the value has been cached while we were waiting for the lock
+             cached = self.processing_pending_cache.get(kbid)
+             if cached is not None:
+                 return cached
+
+             # Get the pending messages and cache the result
+             try:
+                 with back_pressure_observer({"type": "get_processing_pending"}):
+                     pending = await self._get_processing_pending(kbid)
+             except Exception:  # pragma: no cover
+                 # Do not cache if there was an error
+                 logger.exception(
+                     "Error getting pending messages to process. Back pressure on processing for KB can't be applied.",
+                     exc_info=True,
+                     extra={"kbid": kbid},
+                 )
+                 return 0
+
+             if pending > 0:
+                 logger.info(
+                     f"Processing returned {pending} pending messages for KB",
+                     extra={"kbid": kbid},
+                 )
+             self.processing_pending_cache[kbid] = pending
+             return pending
+
+     async def _get_processing_pending(self, kbid: str) -> int:
+         response = await self.processing_http_client.stats(kbid=kbid, timeout=0.5)
+         return response.incomplete
+
+     def get_indexing_pending(self) -> int:
+         return self.indexing_pending
+
+     def get_ingest_pending(self) -> int:
+         return self.ingest_pending
+
+     async def _get_indexing_pending_task(self):
+         try:
+             while True:
+                 try:
+                     with back_pressure_observer({"type": "get_indexing_pending"}):
+                         self.indexing_pending = await get_nats_consumer_pending_messages(
+                             self.nats_manager,
+                             stream="nidx",
+                             consumer="nidx",
+                         )
+                 except Exception:  # pragma: no cover
+                     logger.exception(
+                         "Error getting pending messages to index",
+                         exc_info=True,
+                     )
+                 await asyncio.sleep(self.indexing_check_interval)
+         except asyncio.CancelledError:
+             pass
+
+     async def _get_ingest_pending_task(self):
+         try:
+             while True:
+                 try:
+                     with back_pressure_observer({"type": "get_ingest_pending"}):
+                         self.ingest_pending = await get_nats_consumer_pending_messages(
+                             self.nats_manager,
+                             stream=const.Streams.INGEST_PROCESSED.name,
+                             consumer=const.Streams.INGEST_PROCESSED.group,
+                         )
+                 except Exception:  # pragma: no cover
+                     logger.exception(
+                         "Error getting pending messages to ingest",
+                         exc_info=True,
+                     )
+                 await asyncio.sleep(self.ingest_check_interval)
+         except asyncio.CancelledError:
+             pass
+
+     def check_indexing(self):
+         max_pending = settings.max_indexing_pending
+         if max_pending <= 0:
+             # Indexing back pressure is disabled
+             return
+         pending = self.get_indexing_pending()
+         if pending > max_pending:
+             try_after = estimate_try_after(
+                 rate=settings.indexing_rate,
+                 pending=pending,
+                 max_wait=settings.max_wait_time,
+             )
+             data = BackPressureData(type="indexing", try_after=try_after)
+             raise BackPressureException(data)
+
+     def check_ingest(self):
+         max_pending = settings.max_ingest_pending
+         if max_pending <= 0:
+             # Ingest back pressure is disabled
+             return
+         ingest_pending = self.get_ingest_pending()
+         if ingest_pending > max_pending:
+             try_after = estimate_try_after(
+                 rate=settings.ingest_rate,
+                 pending=ingest_pending,
+                 max_wait=settings.max_wait_time,
+             )
+             data = BackPressureData(type="ingest", try_after=try_after)
+             raise BackPressureException(data)
+
+     async def check_processing(self, kbid: str):
+         max_pending = settings.max_processing_pending
+         if max_pending <= 0:
+             # Processing back pressure is disabled
+             return
+
+         kb_pending = await self.get_processing_pending(kbid)
+         if kb_pending > max_pending:
+             try_after = estimate_try_after(
+                 rate=settings.processing_rate,
+                 pending=kb_pending,
+                 max_wait=settings.max_wait_time,
+             )
+             data = BackPressureData(type="processing", try_after=try_after)
+             raise BackPressureException(data)
+
+
+ MATERIALIZER: Optional[BackPressureMaterializer] = None
+ materializer_lock = threading.Lock()
+
+
+ async def start_materializer(context: ApplicationContext):
+     global MATERIALIZER
+     if MATERIALIZER is not None:
+         logger.warning("BackPressureMaterializer already started")
+         return
+     with materializer_lock:
+         if MATERIALIZER is not None:
+             return
+         logger.info("Initializing materializer")
+         try:
+             nats_manager = context.nats_manager
+         except AttributeError:
+             logger.warning(
+                 "Could not initialize materializer. Nats manager not found or not initialized yet"
+             )
+             return
+         materializer = BackPressureMaterializer(
+             nats_manager,
+             indexing_check_interval=settings.indexing_check_interval,
+             ingest_check_interval=settings.ingest_check_interval,
+         )
+         await materializer.start()
+         MATERIALIZER = materializer
+
+
+ async def stop_materializer():
+     global MATERIALIZER
+     if MATERIALIZER is None or not MATERIALIZER.running:
+         logger.warning("BackPressureMaterializer already stopped")
+         return
+     with materializer_lock:
+         if MATERIALIZER is None:
+             return
+         logger.info("Stopping materializer")
+         await MATERIALIZER.stop()
+         MATERIALIZER = None
+
+
+ def get_materializer() -> BackPressureMaterializer:
+     global MATERIALIZER
+     if MATERIALIZER is None:
+         raise RuntimeError("BackPressureMaterializer not initialized")
+     return MATERIALIZER
+
+
+ async def maybe_back_pressure(kbid: str, resource_uuid: Optional[str] = None) -> None:
+     """
+     This function does system checks to see if we need to put back pressure on writes.
+     In that case, an HTTP 429 will be raised with the estimated time to try again.
+     """
+     if not is_back_pressure_enabled() or is_onprem_nucliadb():
+         return
+     await back_pressure_checks(kbid, resource_uuid)
+
+
+ async def back_pressure_checks(kbid: str, resource_uuid: Optional[str] = None):
+     """
+     Will raise a 429 if back pressure is needed:
+     - If the processing engine is behind.
+     - If the ingest processed consumer is behind.
+     - If the indexing on nodes affected by the request (kbid and resource_uuid) is behind.
+     """
+     materializer = get_materializer()
+     try:
+         with cached_back_pressure(f"{kbid}-{resource_uuid}"):
+             materializer.check_indexing()
+             materializer.check_ingest()
+             await materializer.check_processing(kbid)
+     except BackPressureException as exc:
+         logger.info(
+             "Back pressure applied",
+             extra={
+                 "kbid": kbid,
+                 "resource_uuid": resource_uuid,
+                 "try_after": exc.data.try_after,
+                 "back_pressure_type": exc.data.type,
+             },
+         )
+         raise HTTPException(
+             status_code=429,
+             detail={
+                 "message": f"Too many messages pending to ingest. Retry after {exc.data.try_after}",
+                 "try_after": exc.data.try_after.timestamp(),
+                 "back_pressure_type": exc.data.type,
+             },
+         ) from exc
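The materializer polls NATS consumer stats in two background tasks and answers `check_*` calls from memory, so request handlers never wait on NATS or the processing API. A hedged wiring sketch of the lifecycle; the `run` wrapper and context setup are illustrative, while the imported names are the ones added in this diff:

```python
# Hedged wiring sketch: how a service might run the materializer for its
# lifetime, assuming an already initialized ApplicationContext.
from nucliadb.common.back_pressure import start_materializer, stop_materializer
from nucliadb.common.context import ApplicationContext


async def run(context: ApplicationContext):
    await start_materializer(context)  # spawns the two NATS polling tasks
    try:
        ...  # serve requests; maybe_back_pressure() now reads materialized stats
    finally:
        await stop_materializer()  # cancels the tasks and closes the HTTP client
```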
@@ -0,0 +1,72 @@
+ # Copyright (C) 2021 Bosutech XXI S.L.
+ #
+ # nucliadb is offered under the AGPL v3.0 and as commercial software.
+ # For commercial licensing, contact us at info@nuclia.com.
+ #
+ # AGPL:
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU Affero General Public License as
+ # published by the Free Software Foundation, either version 3 of the
+ # License, or (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU Affero General Public License for more details.
+ #
+ # You should have received a copy of the GNU Affero General Public License
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
+ #
+ from pydantic import Field
+ from pydantic_settings import BaseSettings
+
+
+ class BackPressureSettings(BaseSettings):
+     enabled: bool = Field(
+         default=False,
+         description="Enable or disable back pressure.",
+         alias="back_pressure_enabled",
+     )
+     indexing_rate: float = Field(
+         default=10,
+         description="Estimation of the indexing rate in messages per second. This is used to calculate the try-again time",  # noqa
+     )
+     ingest_rate: float = Field(
+         default=4,
+         description="Estimation of the ingest processed consumer rate in messages per second. This is used to calculate the try-again time",  # noqa
+     )
+     processing_rate: float = Field(
+         default=1,
+         description="Estimation of the processing rate in messages per second. This is used to calculate the try-again time",  # noqa
+     )
+     max_indexing_pending: int = Field(
+         default=1000,
+         description="Max number of messages pending to index in a node queue before rate limiting writes. Set to 0 to disable indexing back pressure checks",  # noqa
+         alias="back_pressure_max_indexing_pending",
+     )
+     max_ingest_pending: int = Field(
+         # Disabled by default
+         default=0,
+         description="Max number of messages pending to be ingested by processed consumers before rate limiting writes. Set to 0 to disable ingest back pressure checks",  # noqa
+         alias="back_pressure_max_ingest_pending",
+     )
+     max_processing_pending: int = Field(
+         default=1000,
+         description="Max number of messages pending to process per Knowledge Box before rate limiting writes. Set to 0 to disable processing back pressure checks",  # noqa
+         alias="back_pressure_max_processing_pending",
+     )
+     indexing_check_interval: int = Field(
+         default=30,
+         description="Interval in seconds to check the indexing pending messages",
+     )
+     ingest_check_interval: int = Field(
+         default=30,
+         description="Interval in seconds to check the ingest pending messages",
+     )
+     max_wait_time: int = Field(
+         default=60,
+         description="Max time in seconds to wait before trying again after back pressure",
+     )
+
+
+ settings = BackPressureSettings()
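These are standard pydantic-settings fields, so they can plausibly be configured from the environment, with aliased fields (e.g. `back_pressure_enabled`) read under the alias name. A sketch under that assumption; the exact env-var resolution depends on pydantic-settings configuration not shown in this diff:

```python
# Sketch assuming aliased fields resolve from same-named environment
# variables (case-insensitive), the usual pydantic-settings behavior.
import os

os.environ["BACK_PRESSURE_ENABLED"] = "true"
os.environ["BACK_PRESSURE_MAX_INGEST_PENDING"] = "500"  # turn ingest checks on

from nucliadb.common.back_pressure.settings import BackPressureSettings

s = BackPressureSettings()
assert s.enabled is True
assert s.max_ingest_pending == 500
```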
@@ -0,0 +1,59 @@
+ # Copyright (C) 2021 Bosutech XXI S.L.
+ #
+ # nucliadb is offered under the AGPL v3.0 and as commercial software.
+ # For commercial licensing, contact us at info@nuclia.com.
+ #
+ # AGPL:
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU Affero General Public License as
+ # published by the Free Software Foundation, either version 3 of the
+ # License, or (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU Affero General Public License for more details.
+ #
+ # You should have received a copy of the GNU Affero General Public License
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
+ #
+ from dataclasses import dataclass
+ from datetime import datetime, timedelta, timezone
+
+ from nucliadb.common.back_pressure.settings import settings
+ from nucliadb_utils.nats import NatsConnectionManager
+
+
+ @dataclass
+ class BackPressureData:
+     type: str
+     try_after: datetime
+
+
+ class BackPressureException(Exception):
+     def __init__(self, data: BackPressureData):
+         self.data = data
+
+     def __str__(self):
+         return f"Back pressure applied for {self.data.type}. Try again after {self.data.try_after}"
+
+
+ def is_back_pressure_enabled() -> bool:
+     return settings.enabled
+
+
+ def estimate_try_after(rate: float, pending: int, max_wait: int) -> datetime:
+     """
+     This function estimates the time to try again based on the rate and the number of pending messages.
+     """
+     delta_seconds = min(pending / rate, max_wait)
+     return datetime.now(timezone.utc) + timedelta(seconds=delta_seconds)
+
+
+ async def get_nats_consumer_pending_messages(
+     nats_manager: NatsConnectionManager, *, stream: str, consumer: str
+ ) -> int:
+     # get raw js client
+     js = nats_manager.js
+     consumer_info = await js.consumer_info(stream, consumer)
+     return consumer_info.num_pending
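A worked example of `estimate_try_after` as defined above: the wait is `pending / rate` seconds, capped at `max_wait`:

```python
# Worked example of estimate_try_after, using only the function above.
from nucliadb.common.back_pressure.utils import estimate_try_after

# 500 pending at an estimated 10 msg/s -> retry in ~50 seconds
t1 = estimate_try_after(rate=10, pending=500, max_wait=60)

# 5000 pending would be ~500 seconds, but the wait is capped at 60
t2 = estimate_try_after(rate=10, pending=5000, max_wait=60)
```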
@@ -55,16 +55,19 @@ class VectorsetExternalIndex:
      similarity: VectorSimilarity.ValueType


- class TextBlockMatch(BaseModel):
+ class ScoredTextBlock(BaseModel):
+     paragraph_id: ParagraphId
+     score: float
+     score_type: SCORE_TYPE
+
+
+ class TextBlockMatch(ScoredTextBlock):
      """
      Model a text block/paragraph retrieved from an external index with all the information
      needed in order to later hydrate retrieval results.
      """

-     paragraph_id: ParagraphId
      position: TextPosition
-     score: float
-     score_type: SCORE_TYPE
      order: int
      page_with_visual: bool = False
      fuzzy_search: bool
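This hunk splits the score-carrying fields out of `TextBlockMatch` into a shared `ScoredTextBlock` base, so other scored block types can reuse them (the `_FindParagraph` class later in this diff does exactly that). A standalone restatement of the pattern in plain pydantic, with stand-in names rather than nucliadb classes, to show why the shared base is useful:

```python
# Plain-pydantic restatement; ScoredBlock and Match are stand-ins, not
# nucliadb classes. Only the shape of the refactor is shown.
from pydantic import BaseModel


class ScoredBlock(BaseModel):  # stands in for ScoredTextBlock
    paragraph_id: str
    score: float


class Match(ScoredBlock):  # stands in for TextBlockMatch
    order: int = 0


def best(blocks: list[ScoredBlock]) -> ScoredBlock:
    # Generic ranking code only needs the base-class fields, so it works
    # for Match and any other ScoredBlock subclass alike.
    return max(blocks, key=lambda b: b.score)
```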
@@ -33,6 +33,8 @@ from nuclia_models.predict.generative_responses import (
  from pydantic_core import ValidationError

  from nucliadb.common.datamanagers.exceptions import KnowledgeBoxNotFound
+ from nucliadb.common.external_index_providers.base import ScoredTextBlock
+ from nucliadb.common.ids import ParagraphId
  from nucliadb.models.responses import HTTPClientError
  from nucliadb.search import logger, predict
  from nucliadb.search.predict import (
@@ -63,6 +65,7 @@ from nucliadb.search.search.graph_strategy import get_graph_results
  from nucliadb.search.search.metrics import RAGMetrics
  from nucliadb.search.search.query_parser.fetcher import Fetcher
  from nucliadb.search.search.query_parser.parsers.ask import fetcher_for_ask, parse_ask
+ from nucliadb.search.search.rank_fusion import WeightedCombSum
  from nucliadb.search.search.rerankers import (
      get_reranker,
  )
@@ -865,6 +868,10 @@ async def retrieval_in_resource(
      )


+ class _FindParagraph(ScoredTextBlock):
+     original: FindParagraph
+
+
  def compute_best_matches(
      main_results: KnowledgeboxFindResults,
      prequeries_results: Optional[list[PreQueryResult]] = None,
@@ -882,42 +889,46 @@ def compute_best_matches(
      `main_query_weight` is the weight given to the paragraphs matching the main query when calculating the final score.
      """

-     def iter_paragraphs(results: KnowledgeboxFindResults):
+     def extract_paragraphs(results: KnowledgeboxFindResults) -> list[_FindParagraph]:
+         paragraphs = []
          for resource in results.resources.values():
              for field in resource.fields.values():
                  for paragraph in field.paragraphs.values():
-                     yield paragraph
-
-     total_weights = main_query_weight + sum(prequery.weight for prequery, _ in prequeries_results or [])
-     paragraph_id_to_match: dict[str, RetrievalMatch] = {}
-     for paragraph in iter_paragraphs(main_results):
-         normalized_weight = main_query_weight / total_weights
-         rmatch = RetrievalMatch(
-             paragraph=paragraph,
-             weighted_score=paragraph.score * normalized_weight,
-         )
-         paragraph_id_to_match[paragraph.id] = rmatch
-
-     for prequery, prequery_results in prequeries_results or []:
-         for paragraph in iter_paragraphs(prequery_results):
-             normalized_weight = prequery.weight / total_weights
-             weighted_score = paragraph.score * normalized_weight
-             if paragraph.id in paragraph_id_to_match:
-                 rmatch = paragraph_id_to_match[paragraph.id]
-                 # If a paragraph is matched in various prequeries, the final score is the
-                 # sum of the weighted scores
-                 rmatch.weighted_score += weighted_score
-             else:
-                 paragraph_id_to_match[paragraph.id] = RetrievalMatch(
-                     paragraph=paragraph,
-                     weighted_score=weighted_score,
-                 )
+                     paragraphs.append(
+                         _FindParagraph(
+                             paragraph_id=ParagraphId.from_string(paragraph.id),
+                             score=paragraph.score,
+                             score_type=paragraph.score_type,
+                             original=paragraph,
+                         )
+                     )
+         return paragraphs

-     return sorted(
-         paragraph_id_to_match.values(),
-         key=lambda match: match.weighted_score,
-         reverse=True,
-     )
+     weights = {
+         "main": main_query_weight,
+     }
+     total_weight = main_query_weight
+     find_results = {
+         "main": extract_paragraphs(main_results),
+     }
+     for i, (prequery, prequery_results) in enumerate(prequeries_results or []):
+         weights[f"prequery-{i}"] = prequery.weight
+         total_weight += prequery.weight
+         find_results[f"prequery-{i}"] = extract_paragraphs(prequery_results)
+
+     normalized_weights = {key: value / total_weight for key, value in weights.items()}
+
+     # window does nothing here
+     rank_fusion = WeightedCombSum(window=0, weights=normalized_weights)
+
+     merged = []
+     for item in rank_fusion.fuse(find_results):
+         match = RetrievalMatch(
+             paragraph=item.original,
+             weighted_score=item.score,
+         )
+         merged.append(match)
+     return merged


  def calculate_prequeries_for_json_schema(
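The rewritten `compute_best_matches` delegates the weighted merge to `WeightedCombSum` instead of accumulating scores by hand. Per the removed code, the semantics are weighted CombSUM: each result list gets a normalized weight, and a paragraph's fused score is the sum of `weight * score` over the lists it appears in. A standalone sketch of that arithmetic (not the `WeightedCombSum` implementation itself):

```python
# Standalone sketch of weighted CombSUM, mirroring the semantics of the
# replaced hand-rolled loop; ids and scores are made-up example data.
main_query_weight = 1.0
prequery_weight = 3.0
total = main_query_weight + prequery_weight
weights = {"main": main_query_weight / total, "prequery-0": prequery_weight / total}

rankings = {
    "main": {"p1": 0.8, "p2": 0.9},  # paragraph id -> retrieval score
    "prequery-0": {"p1": 0.4},
}

fused: dict[str, float] = {}
for key, blocks in rankings.items():
    for pid, score in blocks.items():
        fused[pid] = fused.get(pid, 0.0) + weights[key] * score

# p1 = 0.25 * 0.8 + 0.75 * 0.4 = 0.5; p2 = 0.25 * 0.9 = 0.225
best_first = sorted(fused.items(), key=lambda kv: kv[1], reverse=True)
```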