nucliadb 6.9.1.post5192__py3-none-any.whl → 6.10.0.post5705__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- migrations/0023_backfill_pg_catalog.py +2 -2
- migrations/0029_backfill_field_status.py +3 -4
- migrations/0032_remove_old_relations.py +2 -3
- migrations/0038_backfill_catalog_field_labels.py +2 -2
- migrations/0039_backfill_converation_splits_metadata.py +2 -2
- migrations/0041_reindex_conversations.py +137 -0
- migrations/pg/0010_shards_index.py +34 -0
- nucliadb/search/api/v1/resource/utils.py → migrations/pg/0011_catalog_statistics.py +5 -6
- migrations/pg/0012_catalog_statistics_undo.py +26 -0
- nucliadb/backups/create.py +2 -15
- nucliadb/backups/restore.py +4 -15
- nucliadb/backups/tasks.py +4 -1
- nucliadb/common/back_pressure/cache.py +2 -3
- nucliadb/common/back_pressure/materializer.py +7 -13
- nucliadb/common/back_pressure/settings.py +6 -6
- nucliadb/common/back_pressure/utils.py +1 -0
- nucliadb/common/cache.py +9 -9
- nucliadb/common/catalog/interface.py +12 -12
- nucliadb/common/catalog/pg.py +41 -29
- nucliadb/common/catalog/utils.py +3 -3
- nucliadb/common/cluster/manager.py +5 -4
- nucliadb/common/cluster/rebalance.py +483 -114
- nucliadb/common/cluster/rollover.py +25 -9
- nucliadb/common/cluster/settings.py +3 -8
- nucliadb/common/cluster/utils.py +34 -8
- nucliadb/common/context/__init__.py +7 -8
- nucliadb/common/context/fastapi.py +1 -2
- nucliadb/common/datamanagers/__init__.py +2 -4
- nucliadb/common/datamanagers/atomic.py +4 -2
- nucliadb/common/datamanagers/cluster.py +1 -2
- nucliadb/common/datamanagers/fields.py +3 -4
- nucliadb/common/datamanagers/kb.py +6 -6
- nucliadb/common/datamanagers/labels.py +2 -3
- nucliadb/common/datamanagers/resources.py +10 -33
- nucliadb/common/datamanagers/rollover.py +5 -7
- nucliadb/common/datamanagers/search_configurations.py +1 -2
- nucliadb/common/datamanagers/synonyms.py +1 -2
- nucliadb/common/datamanagers/utils.py +4 -4
- nucliadb/common/datamanagers/vectorsets.py +4 -4
- nucliadb/common/external_index_providers/base.py +32 -5
- nucliadb/common/external_index_providers/manager.py +4 -5
- nucliadb/common/filter_expression.py +128 -40
- nucliadb/common/http_clients/processing.py +12 -23
- nucliadb/common/ids.py +6 -4
- nucliadb/common/locking.py +1 -2
- nucliadb/common/maindb/driver.py +9 -8
- nucliadb/common/maindb/local.py +5 -5
- nucliadb/common/maindb/pg.py +9 -8
- nucliadb/common/nidx.py +3 -4
- nucliadb/export_import/datamanager.py +4 -3
- nucliadb/export_import/exporter.py +11 -19
- nucliadb/export_import/importer.py +13 -6
- nucliadb/export_import/tasks.py +2 -0
- nucliadb/export_import/utils.py +6 -18
- nucliadb/health.py +2 -2
- nucliadb/ingest/app.py +8 -8
- nucliadb/ingest/consumer/consumer.py +8 -10
- nucliadb/ingest/consumer/pull.py +3 -8
- nucliadb/ingest/consumer/service.py +3 -3
- nucliadb/ingest/consumer/utils.py +1 -1
- nucliadb/ingest/fields/base.py +28 -49
- nucliadb/ingest/fields/conversation.py +12 -12
- nucliadb/ingest/fields/exceptions.py +1 -2
- nucliadb/ingest/fields/file.py +22 -8
- nucliadb/ingest/fields/link.py +7 -7
- nucliadb/ingest/fields/text.py +2 -3
- nucliadb/ingest/orm/brain_v2.py +78 -64
- nucliadb/ingest/orm/broker_message.py +2 -4
- nucliadb/ingest/orm/entities.py +10 -209
- nucliadb/ingest/orm/index_message.py +4 -4
- nucliadb/ingest/orm/knowledgebox.py +18 -27
- nucliadb/ingest/orm/processor/auditing.py +1 -3
- nucliadb/ingest/orm/processor/data_augmentation.py +1 -2
- nucliadb/ingest/orm/processor/processor.py +27 -27
- nucliadb/ingest/orm/processor/sequence_manager.py +1 -2
- nucliadb/ingest/orm/resource.py +72 -70
- nucliadb/ingest/orm/utils.py +1 -1
- nucliadb/ingest/processing.py +17 -17
- nucliadb/ingest/serialize.py +202 -145
- nucliadb/ingest/service/writer.py +3 -109
- nucliadb/ingest/settings.py +3 -4
- nucliadb/ingest/utils.py +1 -2
- nucliadb/learning_proxy.py +11 -11
- nucliadb/metrics_exporter.py +5 -4
- nucliadb/middleware/__init__.py +82 -1
- nucliadb/migrator/datamanager.py +3 -4
- nucliadb/migrator/migrator.py +1 -2
- nucliadb/migrator/models.py +1 -2
- nucliadb/migrator/settings.py +1 -2
- nucliadb/models/internal/augment.py +614 -0
- nucliadb/models/internal/processing.py +19 -19
- nucliadb/openapi.py +2 -2
- nucliadb/purge/__init__.py +3 -8
- nucliadb/purge/orphan_shards.py +1 -2
- nucliadb/reader/__init__.py +5 -0
- nucliadb/reader/api/models.py +6 -13
- nucliadb/reader/api/v1/download.py +59 -38
- nucliadb/reader/api/v1/export_import.py +4 -4
- nucliadb/reader/api/v1/learning_config.py +24 -4
- nucliadb/reader/api/v1/resource.py +61 -9
- nucliadb/reader/api/v1/services.py +18 -14
- nucliadb/reader/app.py +3 -1
- nucliadb/reader/reader/notifications.py +1 -2
- nucliadb/search/api/v1/__init__.py +2 -0
- nucliadb/search/api/v1/ask.py +3 -4
- nucliadb/search/api/v1/augment.py +585 -0
- nucliadb/search/api/v1/catalog.py +11 -15
- nucliadb/search/api/v1/find.py +16 -22
- nucliadb/search/api/v1/hydrate.py +25 -25
- nucliadb/search/api/v1/knowledgebox.py +1 -2
- nucliadb/search/api/v1/predict_proxy.py +1 -2
- nucliadb/search/api/v1/resource/ask.py +7 -7
- nucliadb/search/api/v1/resource/ingestion_agents.py +5 -6
- nucliadb/search/api/v1/resource/search.py +9 -11
- nucliadb/search/api/v1/retrieve.py +130 -0
- nucliadb/search/api/v1/search.py +28 -32
- nucliadb/search/api/v1/suggest.py +11 -14
- nucliadb/search/api/v1/summarize.py +1 -2
- nucliadb/search/api/v1/utils.py +2 -2
- nucliadb/search/app.py +3 -2
- nucliadb/search/augmentor/__init__.py +21 -0
- nucliadb/search/augmentor/augmentor.py +232 -0
- nucliadb/search/augmentor/fields.py +704 -0
- nucliadb/search/augmentor/metrics.py +24 -0
- nucliadb/search/augmentor/paragraphs.py +334 -0
- nucliadb/search/augmentor/resources.py +238 -0
- nucliadb/search/augmentor/utils.py +33 -0
- nucliadb/search/lifecycle.py +3 -1
- nucliadb/search/predict.py +24 -17
- nucliadb/search/predict_models.py +8 -9
- nucliadb/search/requesters/utils.py +11 -10
- nucliadb/search/search/cache.py +19 -23
- nucliadb/search/search/chat/ask.py +88 -59
- nucliadb/search/search/chat/exceptions.py +3 -5
- nucliadb/search/search/chat/fetcher.py +201 -0
- nucliadb/search/search/chat/images.py +6 -4
- nucliadb/search/search/chat/old_prompt.py +1375 -0
- nucliadb/search/search/chat/parser.py +510 -0
- nucliadb/search/search/chat/prompt.py +563 -615
- nucliadb/search/search/chat/query.py +449 -36
- nucliadb/search/search/chat/rpc.py +85 -0
- nucliadb/search/search/fetch.py +3 -4
- nucliadb/search/search/filters.py +8 -11
- nucliadb/search/search/find.py +33 -31
- nucliadb/search/search/find_merge.py +124 -331
- nucliadb/search/search/graph_strategy.py +14 -12
- nucliadb/search/search/hydrator/__init__.py +3 -152
- nucliadb/search/search/hydrator/fields.py +92 -50
- nucliadb/search/search/hydrator/images.py +7 -7
- nucliadb/search/search/hydrator/paragraphs.py +42 -26
- nucliadb/search/search/hydrator/resources.py +20 -16
- nucliadb/search/search/ingestion_agents.py +5 -5
- nucliadb/search/search/merge.py +90 -94
- nucliadb/search/search/metrics.py +10 -9
- nucliadb/search/search/paragraphs.py +7 -9
- nucliadb/search/search/predict_proxy.py +13 -9
- nucliadb/search/search/query.py +14 -86
- nucliadb/search/search/query_parser/fetcher.py +51 -82
- nucliadb/search/search/query_parser/models.py +19 -20
- nucliadb/search/search/query_parser/old_filters.py +20 -19
- nucliadb/search/search/query_parser/parsers/ask.py +4 -5
- nucliadb/search/search/query_parser/parsers/catalog.py +5 -6
- nucliadb/search/search/query_parser/parsers/common.py +5 -6
- nucliadb/search/search/query_parser/parsers/find.py +6 -26
- nucliadb/search/search/query_parser/parsers/graph.py +13 -23
- nucliadb/search/search/query_parser/parsers/retrieve.py +207 -0
- nucliadb/search/search/query_parser/parsers/search.py +15 -53
- nucliadb/search/search/query_parser/parsers/unit_retrieval.py +8 -29
- nucliadb/search/search/rank_fusion.py +18 -13
- nucliadb/search/search/rerankers.py +5 -6
- nucliadb/search/search/retrieval.py +300 -0
- nucliadb/search/search/summarize.py +5 -6
- nucliadb/search/search/utils.py +3 -4
- nucliadb/search/settings.py +1 -2
- nucliadb/standalone/api_router.py +1 -1
- nucliadb/standalone/app.py +4 -3
- nucliadb/standalone/auth.py +5 -6
- nucliadb/standalone/lifecycle.py +2 -2
- nucliadb/standalone/run.py +2 -4
- nucliadb/standalone/settings.py +5 -6
- nucliadb/standalone/versions.py +3 -4
- nucliadb/tasks/consumer.py +13 -8
- nucliadb/tasks/models.py +2 -1
- nucliadb/tasks/producer.py +3 -3
- nucliadb/tasks/retries.py +8 -7
- nucliadb/train/api/utils.py +1 -3
- nucliadb/train/api/v1/shards.py +1 -2
- nucliadb/train/api/v1/trainset.py +1 -2
- nucliadb/train/app.py +1 -1
- nucliadb/train/generator.py +4 -4
- nucliadb/train/generators/field_classifier.py +2 -2
- nucliadb/train/generators/field_streaming.py +6 -6
- nucliadb/train/generators/image_classifier.py +2 -2
- nucliadb/train/generators/paragraph_classifier.py +2 -2
- nucliadb/train/generators/paragraph_streaming.py +2 -2
- nucliadb/train/generators/question_answer_streaming.py +2 -2
- nucliadb/train/generators/sentence_classifier.py +2 -2
- nucliadb/train/generators/token_classifier.py +3 -2
- nucliadb/train/generators/utils.py +6 -5
- nucliadb/train/nodes.py +3 -3
- nucliadb/train/resource.py +6 -8
- nucliadb/train/settings.py +3 -4
- nucliadb/train/types.py +11 -11
- nucliadb/train/upload.py +3 -2
- nucliadb/train/uploader.py +1 -2
- nucliadb/train/utils.py +1 -2
- nucliadb/writer/api/v1/export_import.py +4 -1
- nucliadb/writer/api/v1/field.py +7 -11
- nucliadb/writer/api/v1/knowledgebox.py +3 -4
- nucliadb/writer/api/v1/resource.py +9 -20
- nucliadb/writer/api/v1/services.py +10 -132
- nucliadb/writer/api/v1/upload.py +73 -72
- nucliadb/writer/app.py +8 -2
- nucliadb/writer/resource/basic.py +12 -15
- nucliadb/writer/resource/field.py +7 -5
- nucliadb/writer/resource/origin.py +7 -0
- nucliadb/writer/settings.py +2 -3
- nucliadb/writer/tus/__init__.py +2 -3
- nucliadb/writer/tus/azure.py +1 -3
- nucliadb/writer/tus/dm.py +3 -3
- nucliadb/writer/tus/exceptions.py +3 -4
- nucliadb/writer/tus/gcs.py +5 -6
- nucliadb/writer/tus/s3.py +2 -3
- nucliadb/writer/tus/storage.py +3 -3
- {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/METADATA +9 -10
- nucliadb-6.10.0.post5705.dist-info/RECORD +410 -0
- nucliadb/common/datamanagers/entities.py +0 -139
- nucliadb-6.9.1.post5192.dist-info/RECORD +0 -392
- {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/WHEEL +0 -0
- {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/top_level.txt +0 -0
|
@@ -20,7 +20,8 @@
|
|
|
20
20
|
import dataclasses
|
|
21
21
|
import functools
|
|
22
22
|
import json
|
|
23
|
-
from
|
|
23
|
+
from collections.abc import AsyncGenerator
|
|
24
|
+
from typing import cast
|
|
24
25
|
|
|
25
26
|
from nuclia_models.common.consumption import Consumption
|
|
26
27
|
from nuclia_models.predict.generative_responses import (
|
|
@@ -34,6 +35,7 @@ from nuclia_models.predict.generative_responses import (
|
|
|
34
35
|
TextGenerativeResponse,
|
|
35
36
|
)
|
|
36
37
|
from pydantic_core import ValidationError
|
|
38
|
+
from typing_extensions import assert_never
|
|
37
39
|
|
|
38
40
|
from nucliadb.common.datamanagers.exceptions import KnowledgeBoxNotFound
|
|
39
41
|
from nucliadb.common.exceptions import InvalidQueryError
|
|
@@ -49,11 +51,13 @@ from nucliadb.search.search.chat.exceptions import (
|
|
|
49
51
|
AnswerJsonSchemaTooLong,
|
|
50
52
|
NoRetrievalResultsError,
|
|
51
53
|
)
|
|
54
|
+
from nucliadb.search.search.chat.old_prompt import PromptContextBuilder as OldPromptContextBuilder
|
|
52
55
|
from nucliadb.search.search.chat.prompt import PromptContextBuilder
|
|
53
56
|
from nucliadb.search.search.chat.query import (
|
|
54
57
|
NOT_ENOUGH_CONTEXT_ANSWER,
|
|
55
58
|
ChatAuditor,
|
|
56
59
|
add_resource_filter,
|
|
60
|
+
get_answer_stream,
|
|
57
61
|
get_find_results,
|
|
58
62
|
get_relations_results,
|
|
59
63
|
maybe_audit_chat,
|
|
@@ -69,11 +73,15 @@ from nucliadb.search.search.metrics import AskMetrics, Metrics
|
|
|
69
73
|
from nucliadb.search.search.query_parser.fetcher import Fetcher
|
|
70
74
|
from nucliadb.search.search.query_parser.parsers.ask import fetcher_for_ask, parse_ask
|
|
71
75
|
from nucliadb.search.search.rank_fusion import WeightedCombSum
|
|
72
|
-
from
|
|
73
|
-
|
|
76
|
+
from nucliadb_models.retrieval import (
|
|
77
|
+
GraphScore,
|
|
78
|
+
KeywordScore,
|
|
79
|
+
RerankerScore,
|
|
80
|
+
RrfScore,
|
|
81
|
+
SemanticScore,
|
|
74
82
|
)
|
|
75
|
-
from nucliadb.search.utilities import get_predict
|
|
76
83
|
from nucliadb_models.search import (
|
|
84
|
+
SCORE_TYPE,
|
|
77
85
|
AnswerAskResponseItem,
|
|
78
86
|
AskRequest,
|
|
79
87
|
AskResponseItem,
|
|
@@ -118,7 +126,9 @@ from nucliadb_models.search import (
|
|
|
118
126
|
parse_rephrase_prompt,
|
|
119
127
|
)
|
|
120
128
|
from nucliadb_telemetry import errors
|
|
129
|
+
from nucliadb_utils import const
|
|
121
130
|
from nucliadb_utils.exceptions import LimitsExceededError
|
|
131
|
+
from nucliadb_utils.utilities import has_feature
|
|
122
132
|
|
|
123
133
|
|
|
124
134
|
@dataclasses.dataclass
|
|
@@ -132,7 +142,7 @@ class RetrievalResults:
|
|
|
132
142
|
main_query: KnowledgeboxFindResults
|
|
133
143
|
fetcher: Fetcher
|
|
134
144
|
main_query_weight: float
|
|
135
|
-
prequeries:
|
|
145
|
+
prequeries: list[PreQueryResult] | None = None
|
|
136
146
|
best_matches: list[RetrievalMatch] = dataclasses.field(default_factory=list)
|
|
137
147
|
|
|
138
148
|
|
|
@@ -143,15 +153,15 @@ class AskResult:
|
|
|
143
153
|
kbid: str,
|
|
144
154
|
ask_request: AskRequest,
|
|
145
155
|
main_results: KnowledgeboxFindResults,
|
|
146
|
-
prequeries_results:
|
|
147
|
-
nuclia_learning_id:
|
|
148
|
-
predict_answer_stream:
|
|
156
|
+
prequeries_results: list[PreQueryResult] | None,
|
|
157
|
+
nuclia_learning_id: str | None,
|
|
158
|
+
predict_answer_stream: AsyncGenerator[GenerativeChunk, None] | None,
|
|
149
159
|
prompt_context: PromptContext,
|
|
150
160
|
prompt_context_order: PromptContextOrder,
|
|
151
161
|
auditor: ChatAuditor,
|
|
152
162
|
metrics: AskMetrics,
|
|
153
163
|
best_matches: list[RetrievalMatch],
|
|
154
|
-
debug_chat_model:
|
|
164
|
+
debug_chat_model: ChatModel | None,
|
|
155
165
|
augmented_context: AugmentedContext,
|
|
156
166
|
):
|
|
157
167
|
# Initial attributes
|
|
@@ -171,14 +181,14 @@ class AskResult:
|
|
|
171
181
|
|
|
172
182
|
# Computed from the predict chat answer stream
|
|
173
183
|
self._answer_text = ""
|
|
174
|
-
self._reasoning_text:
|
|
175
|
-
self._object:
|
|
176
|
-
self._status:
|
|
177
|
-
self._citations:
|
|
178
|
-
self._footnote_citations:
|
|
179
|
-
self._metadata:
|
|
180
|
-
self._relations:
|
|
181
|
-
self._consumption:
|
|
184
|
+
self._reasoning_text: str | None = None
|
|
185
|
+
self._object: JSONGenerativeResponse | None = None
|
|
186
|
+
self._status: StatusGenerativeResponse | None = None
|
|
187
|
+
self._citations: CitationsGenerativeResponse | None = None
|
|
188
|
+
self._footnote_citations: FootnoteCitationsGenerativeResponse | None = None
|
|
189
|
+
self._metadata: MetaGenerativeResponse | None = None
|
|
190
|
+
self._relations: Relations | None = None
|
|
191
|
+
self._consumption: Consumption | None = None
|
|
182
192
|
|
|
183
193
|
@property
|
|
184
194
|
def status_code(self) -> AnswerStatusCode:
|
|
@@ -187,7 +197,7 @@ class AskResult:
|
|
|
187
197
|
return AnswerStatusCode(self._status.code)
|
|
188
198
|
|
|
189
199
|
@property
|
|
190
|
-
def status_error_details(self) ->
|
|
200
|
+
def status_error_details(self) -> str | None:
|
|
191
201
|
if self._status is None: # pragma: no cover
|
|
192
202
|
return None
|
|
193
203
|
return self._status.details
|
|
@@ -240,9 +250,7 @@ class AskResult:
|
|
|
240
250
|
self.metrics.record_first_reasoning_chunk_yielded()
|
|
241
251
|
first_reasoning_chunk_yielded = True
|
|
242
252
|
else:
|
|
243
|
-
|
|
244
|
-
# that is, if we are missing some ifs
|
|
245
|
-
_a: int = "a"
|
|
253
|
+
assert_never(answer_chunk)
|
|
246
254
|
|
|
247
255
|
if self._object is not None:
|
|
248
256
|
yield JSONAskResponseItem(object=self._object.object)
|
|
@@ -396,7 +404,7 @@ class AskResult:
|
|
|
396
404
|
if self._object is not None:
|
|
397
405
|
answer_json = self._object.object
|
|
398
406
|
|
|
399
|
-
prequeries_results:
|
|
407
|
+
prequeries_results: dict[str, KnowledgeboxFindResults] | None = None
|
|
400
408
|
if self.prequeries_results:
|
|
401
409
|
prequeries_results = {}
|
|
402
410
|
for index, (prequery, result) in enumerate(self.prequeries_results):
|
|
@@ -452,7 +460,7 @@ class AskResult:
|
|
|
452
460
|
|
|
453
461
|
async def _stream_predict_answer_text(
|
|
454
462
|
self,
|
|
455
|
-
) -> AsyncGenerator[
|
|
463
|
+
) -> AsyncGenerator[TextGenerativeResponse | ReasoningGenerativeResponse, None]:
|
|
456
464
|
"""
|
|
457
465
|
Reads the stream of the generative model, yielding the answer text but also parsing
|
|
458
466
|
other items like status codes, citations and miscellaneous metadata.
|
|
@@ -496,8 +504,8 @@ class AskResult:
|
|
|
496
504
|
class NotEnoughContextAskResult(AskResult):
|
|
497
505
|
def __init__(
|
|
498
506
|
self,
|
|
499
|
-
main_results:
|
|
500
|
-
prequeries_results:
|
|
507
|
+
main_results: KnowledgeboxFindResults | None = None,
|
|
508
|
+
prequeries_results: list[PreQueryResult] | None = None,
|
|
501
509
|
):
|
|
502
510
|
self.main_results = main_results or KnowledgeboxFindResults(resources={}, min_score=None)
|
|
503
511
|
self.prequeries_results = prequeries_results or []
|
|
@@ -547,8 +555,8 @@ async def ask(
|
|
|
547
555
|
user_id: str,
|
|
548
556
|
client_type: NucliaDBClientType,
|
|
549
557
|
origin: str,
|
|
550
|
-
resource:
|
|
551
|
-
extra_predict_headers:
|
|
558
|
+
resource: str | None = None,
|
|
559
|
+
extra_predict_headers: dict[str, str] | None = None,
|
|
552
560
|
) -> AskResult:
|
|
553
561
|
metrics = AskMetrics()
|
|
554
562
|
chat_history = ask_request.chat_history or []
|
|
@@ -627,19 +635,36 @@ async def ask(
|
|
|
627
635
|
|
|
628
636
|
# Now we build the prompt context
|
|
629
637
|
with metrics.time("context_building"):
|
|
630
|
-
prompt_context_builder
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
638
|
+
prompt_context_builder: PromptContextBuilder | OldPromptContextBuilder
|
|
639
|
+
if has_feature(const.Features.ASK_DECOUPLED, context={"kbid": kbid}):
|
|
640
|
+
prompt_context_builder = PromptContextBuilder(
|
|
641
|
+
kbid=kbid,
|
|
642
|
+
ordered_paragraphs=[match.paragraph for match in retrieval_results.best_matches],
|
|
643
|
+
resource=resource,
|
|
644
|
+
user_context=user_context,
|
|
645
|
+
user_image_context=ask_request.extra_context_images,
|
|
646
|
+
strategies=ask_request.rag_strategies,
|
|
647
|
+
image_strategies=ask_request.rag_images_strategies,
|
|
648
|
+
max_context_characters=tokens_to_chars(generation.max_context_tokens),
|
|
649
|
+
visual_llm=generation.use_visual_llm,
|
|
650
|
+
query_image=ask_request.query_image,
|
|
651
|
+
metrics=metrics.child_span("context_building"),
|
|
652
|
+
)
|
|
653
|
+
else:
|
|
654
|
+
prompt_context_builder = OldPromptContextBuilder(
|
|
655
|
+
kbid=kbid,
|
|
656
|
+
ordered_paragraphs=[match.paragraph for match in retrieval_results.best_matches],
|
|
657
|
+
resource=resource,
|
|
658
|
+
user_context=user_context,
|
|
659
|
+
user_image_context=ask_request.extra_context_images,
|
|
660
|
+
strategies=ask_request.rag_strategies,
|
|
661
|
+
image_strategies=ask_request.rag_images_strategies,
|
|
662
|
+
max_context_characters=tokens_to_chars(generation.max_context_tokens),
|
|
663
|
+
visual_llm=generation.use_visual_llm,
|
|
664
|
+
query_image=ask_request.query_image,
|
|
665
|
+
metrics=metrics.child_span("context_building"),
|
|
666
|
+
)
|
|
667
|
+
|
|
643
668
|
(
|
|
644
669
|
prompt_context,
|
|
645
670
|
prompt_context_order,
|
|
@@ -675,14 +700,11 @@ async def ask(
|
|
|
675
700
|
predict_answer_stream = None
|
|
676
701
|
if ask_request.generate_answer:
|
|
677
702
|
with metrics.time("stream_start"):
|
|
678
|
-
predict = get_predict()
|
|
679
703
|
(
|
|
680
704
|
nuclia_learning_id,
|
|
681
705
|
nuclia_learning_model,
|
|
682
706
|
predict_answer_stream,
|
|
683
|
-
) = await
|
|
684
|
-
kbid=kbid, item=chat_model, extra_headers=extra_predict_headers
|
|
685
|
-
)
|
|
707
|
+
) = await get_answer_stream(kbid=kbid, item=chat_model, extra_headers=extra_predict_headers)
|
|
686
708
|
|
|
687
709
|
auditor = ChatAuditor(
|
|
688
710
|
kbid=kbid,
|
|
@@ -757,7 +779,7 @@ def handled_ask_exceptions(func):
|
|
|
757
779
|
return wrapper
|
|
758
780
|
|
|
759
781
|
|
|
760
|
-
def parse_prequeries(ask_request: AskRequest) ->
|
|
782
|
+
def parse_prequeries(ask_request: AskRequest) -> PreQueriesStrategy | None:
|
|
761
783
|
query_ids = []
|
|
762
784
|
for rag_strategy in ask_request.rag_strategies:
|
|
763
785
|
if rag_strategy.name == RagStrategyName.PREQUERIES:
|
|
@@ -776,7 +798,7 @@ def parse_prequeries(ask_request: AskRequest) -> Optional[PreQueriesStrategy]:
|
|
|
776
798
|
return None
|
|
777
799
|
|
|
778
800
|
|
|
779
|
-
def parse_graph_strategy(ask_request: AskRequest) ->
|
|
801
|
+
def parse_graph_strategy(ask_request: AskRequest) -> GraphStrategy | None:
|
|
780
802
|
for rag_strategy in ask_request.rag_strategies:
|
|
781
803
|
if rag_strategy.name == RagStrategyName.GRAPH:
|
|
782
804
|
return cast(GraphStrategy, rag_strategy)
|
|
@@ -791,7 +813,7 @@ async def retrieval_step(
|
|
|
791
813
|
user_id: str,
|
|
792
814
|
origin: str,
|
|
793
815
|
metrics: Metrics,
|
|
794
|
-
resource:
|
|
816
|
+
resource: str | None = None,
|
|
795
817
|
) -> RetrievalResults:
|
|
796
818
|
"""
|
|
797
819
|
This function encapsulates all the logic related to retrieval in the ask endpoint.
|
|
@@ -830,7 +852,7 @@ async def retrieval_in_kb(
|
|
|
830
852
|
) -> RetrievalResults:
|
|
831
853
|
prequeries = parse_prequeries(ask_request)
|
|
832
854
|
graph_strategy = parse_graph_strategy(ask_request)
|
|
833
|
-
main_results, prequeries_results,
|
|
855
|
+
main_results, prequeries_results, fetcher, reranker = await get_find_results(
|
|
834
856
|
kbid=kbid,
|
|
835
857
|
query=main_query,
|
|
836
858
|
item=ask_request,
|
|
@@ -842,10 +864,6 @@ async def retrieval_in_kb(
|
|
|
842
864
|
)
|
|
843
865
|
|
|
844
866
|
if graph_strategy is not None:
|
|
845
|
-
assert parsed_query.retrieval.reranker is not None, (
|
|
846
|
-
"find parser must provide a reranking algorithm"
|
|
847
|
-
)
|
|
848
|
-
reranker = get_reranker(parsed_query.retrieval.reranker)
|
|
849
867
|
graph_results, graph_request = await get_graph_results(
|
|
850
868
|
kbid=kbid,
|
|
851
869
|
query=main_query,
|
|
@@ -878,7 +896,7 @@ async def retrieval_in_kb(
|
|
|
878
896
|
return RetrievalResults(
|
|
879
897
|
main_query=main_results,
|
|
880
898
|
prequeries=prequeries_results,
|
|
881
|
-
fetcher=
|
|
899
|
+
fetcher=fetcher,
|
|
882
900
|
main_query_weight=main_query_weight,
|
|
883
901
|
best_matches=best_matches,
|
|
884
902
|
)
|
|
@@ -918,7 +936,7 @@ async def retrieval_in_resource(
|
|
|
918
936
|
)
|
|
919
937
|
add_resource_filter(prequery.request, [resource])
|
|
920
938
|
|
|
921
|
-
main_results, prequeries_results,
|
|
939
|
+
main_results, prequeries_results, fetcher, _ = await get_find_results(
|
|
922
940
|
kbid=kbid,
|
|
923
941
|
query=main_query,
|
|
924
942
|
item=ask_request,
|
|
@@ -941,7 +959,7 @@ async def retrieval_in_resource(
|
|
|
941
959
|
return RetrievalResults(
|
|
942
960
|
main_query=main_results,
|
|
943
961
|
prequeries=prequeries_results,
|
|
944
|
-
fetcher=
|
|
962
|
+
fetcher=fetcher,
|
|
945
963
|
main_query_weight=main_query_weight,
|
|
946
964
|
best_matches=best_matches,
|
|
947
965
|
)
|
|
@@ -953,7 +971,7 @@ class _FindParagraph(ScoredTextBlock):
|
|
|
953
971
|
|
|
954
972
|
def compute_best_matches(
|
|
955
973
|
main_results: KnowledgeboxFindResults,
|
|
956
|
-
prequeries_results:
|
|
974
|
+
prequeries_results: list[PreQueryResult] | None = None,
|
|
957
975
|
main_query_weight: float = 1.0,
|
|
958
976
|
) -> list[RetrievalMatch]:
|
|
959
977
|
"""
|
|
@@ -968,15 +986,27 @@ def compute_best_matches(
|
|
|
968
986
|
`main_query_weight` is the weight given to the paragraphs matching the main query when calculating the final score.
|
|
969
987
|
"""
|
|
970
988
|
|
|
989
|
+
score_type_map = {
|
|
990
|
+
SCORE_TYPE.VECTOR: SemanticScore,
|
|
991
|
+
SCORE_TYPE.BM25: KeywordScore,
|
|
992
|
+
SCORE_TYPE.BOTH: RrfScore, # /find only exposes RRF as rank fusion algorithm
|
|
993
|
+
SCORE_TYPE.RERANKER: RerankerScore,
|
|
994
|
+
SCORE_TYPE.RELATION_RELEVANCE: GraphScore,
|
|
995
|
+
}
|
|
996
|
+
|
|
971
997
|
def extract_paragraphs(results: KnowledgeboxFindResults) -> list[_FindParagraph]:
|
|
972
998
|
paragraphs = []
|
|
973
999
|
for resource in results.resources.values():
|
|
974
1000
|
for field in resource.fields.values():
|
|
975
1001
|
for paragraph in field.paragraphs.values():
|
|
1002
|
+
# TODO(decoupled-ask): we don't know the score history, as
|
|
1003
|
+
# we are using find results. Once we move boolean queries
|
|
1004
|
+
# inside the new retrieval flow we'll move this and have the
|
|
1005
|
+
# proper information to do this rank fusion
|
|
976
1006
|
paragraphs.append(
|
|
977
1007
|
_FindParagraph(
|
|
978
1008
|
paragraph_id=ParagraphId.from_string(paragraph.id),
|
|
979
|
-
score=paragraph.score,
|
|
1009
|
+
scores=[score_type_map[paragraph.score_type](score=paragraph.score)],
|
|
980
1010
|
score_type=paragraph.score_type,
|
|
981
1011
|
original=paragraph,
|
|
982
1012
|
)
|
|
@@ -1012,7 +1042,7 @@ def compute_best_matches(
|
|
|
1012
1042
|
|
|
1013
1043
|
def calculate_prequeries_for_json_schema(
|
|
1014
1044
|
ask_request: AskRequest,
|
|
1015
|
-
) ->
|
|
1045
|
+
) -> PreQueriesStrategy | None:
|
|
1016
1046
|
"""
|
|
1017
1047
|
This function generates a PreQueriesStrategy with a query for each property in the JSON schema
|
|
1018
1048
|
found in ask_request.answer_json_schema.
|
|
@@ -1077,7 +1107,6 @@ def calculate_prequeries_for_json_schema(
|
|
|
1077
1107
|
rephrase=ask_request.rephrase,
|
|
1078
1108
|
rephrase_prompt=parse_rephrase_prompt(ask_request),
|
|
1079
1109
|
security=ask_request.security,
|
|
1080
|
-
autofilter=False,
|
|
1081
1110
|
)
|
|
1082
1111
|
prequery = PreQuery(
|
|
1083
1112
|
request=req,
|
|
@@ -19,17 +19,15 @@
|
|
|
19
19
|
#
|
|
20
20
|
|
|
21
21
|
|
|
22
|
-
from typing import Optional
|
|
23
|
-
|
|
24
22
|
from nucliadb_models.search import KnowledgeboxFindResults, PreQueryResult
|
|
25
23
|
|
|
26
24
|
|
|
27
25
|
class NoRetrievalResultsError(Exception):
|
|
28
26
|
def __init__(
|
|
29
27
|
self,
|
|
30
|
-
main:
|
|
31
|
-
prequeries:
|
|
32
|
-
prefilters:
|
|
28
|
+
main: KnowledgeboxFindResults | None = None,
|
|
29
|
+
prequeries: list[PreQueryResult] | None = None,
|
|
30
|
+
prefilters: list[PreQueryResult] | None = None,
|
|
33
31
|
):
|
|
34
32
|
self.main_query = main
|
|
35
33
|
self.prequeries = prequeries
|
|
@@ -0,0 +1,201 @@
|
|
|
1
|
+
# Copyright (C) 2021 Bosutech XXI S.L.
|
|
2
|
+
#
|
|
3
|
+
# nucliadb is offered under the AGPL v3.0 and as commercial software.
|
|
4
|
+
# For commercial licensing, contact us at info@nuclia.com.
|
|
5
|
+
#
|
|
6
|
+
# AGPL:
|
|
7
|
+
# This program is free software: you can redistribute it and/or modify
|
|
8
|
+
# it under the terms of the GNU Affero General Public License as
|
|
9
|
+
# published by the Free Software Foundation, either version 3 of the
|
|
10
|
+
# License, or (at your option) any later version.
|
|
11
|
+
#
|
|
12
|
+
# This program is distributed in the hope that it will be useful,
|
|
13
|
+
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
14
|
+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
15
|
+
# GNU Affero General Public License for more details.
|
|
16
|
+
#
|
|
17
|
+
# You should have received a copy of the GNU Affero General Public License
|
|
18
|
+
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
19
|
+
#
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
from google.protobuf.json_format import ParseDict
|
|
23
|
+
|
|
24
|
+
from nucliadb.common.exceptions import InvalidQueryError
|
|
25
|
+
from nucliadb.search import logger
|
|
26
|
+
from nucliadb.search.predict import SendToPredictError, convert_relations
|
|
27
|
+
from nucliadb.search.predict_models import QueryModel
|
|
28
|
+
from nucliadb.search.search.chat import rpc
|
|
29
|
+
from nucliadb.search.search.query_parser.fetcher import Fetcher
|
|
30
|
+
from nucliadb.search.utilities import get_predict
|
|
31
|
+
from nucliadb_models.internal.predict import QueryInfo
|
|
32
|
+
from nucliadb_models.search import Image, MaxTokens
|
|
33
|
+
from nucliadb_protos import knowledgebox_pb2, utils_pb2
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class RAOFetcher(Fetcher):
|
|
37
|
+
def __init__(
|
|
38
|
+
self,
|
|
39
|
+
kbid: str,
|
|
40
|
+
*,
|
|
41
|
+
query: str,
|
|
42
|
+
user_vector: list[float] | None,
|
|
43
|
+
vectorset: str | None,
|
|
44
|
+
rephrase: bool,
|
|
45
|
+
rephrase_prompt: str | None,
|
|
46
|
+
generative_model: str | None,
|
|
47
|
+
query_image: Image | None,
|
|
48
|
+
):
|
|
49
|
+
super().__init__(
|
|
50
|
+
kbid,
|
|
51
|
+
query=query,
|
|
52
|
+
user_vector=user_vector,
|
|
53
|
+
vectorset=vectorset,
|
|
54
|
+
rephrase=rephrase,
|
|
55
|
+
rephrase_prompt=rephrase_prompt,
|
|
56
|
+
generative_model=generative_model,
|
|
57
|
+
query_image=query_image,
|
|
58
|
+
)
|
|
59
|
+
|
|
60
|
+
self._query_info: QueryInfo | None = None
|
|
61
|
+
self._vectorset: str | None = None
|
|
62
|
+
|
|
63
|
+
async def query_information(self) -> QueryInfo:
|
|
64
|
+
if self._query_info is None:
|
|
65
|
+
self._query_info = await query_information(
|
|
66
|
+
kbid=self.kbid,
|
|
67
|
+
query=self.query,
|
|
68
|
+
semantic_model=self.user_vectorset,
|
|
69
|
+
generative_model=self.generative_model,
|
|
70
|
+
rephrase=self.rephrase,
|
|
71
|
+
rephrase_prompt=self.rephrase_prompt,
|
|
72
|
+
query_image=self.query_image,
|
|
73
|
+
)
|
|
74
|
+
return self._query_info
|
|
75
|
+
|
|
76
|
+
# Retrieval
|
|
77
|
+
|
|
78
|
+
async def get_rephrased_query(self) -> str | None:
|
|
79
|
+
query_info = await self.query_information()
|
|
80
|
+
return query_info.rephrased_query
|
|
81
|
+
|
|
82
|
+
async def get_detected_entities(self) -> list[utils_pb2.RelationNode]:
|
|
83
|
+
query_info = await self.query_information()
|
|
84
|
+
if query_info.entities is not None:
|
|
85
|
+
detected_entities = convert_relations(query_info.entities.model_dump())
|
|
86
|
+
else:
|
|
87
|
+
detected_entities = []
|
|
88
|
+
return detected_entities
|
|
89
|
+
|
|
90
|
+
async def get_semantic_min_score(self) -> float | None:
|
|
91
|
+
query_info = await self.query_information()
|
|
92
|
+
vectorset = await self.get_vectorset()
|
|
93
|
+
return query_info.semantic_thresholds.get(vectorset, None)
|
|
94
|
+
|
|
95
|
+
async def get_vectorset(self) -> str:
    """Resolve, remember and return the vectorset to use for this query.

    Resolution order:
      1. a value already stored in ``self._vectorset``;
      2. ``self.user_vectorset`` when it is not ``None``;
      3. the first key of ``sentence.vectors`` in the query information.

    Raises:
        SendToPredictError: if the query information has no sentence or its
            ``vectors`` mapping is empty.
    """
    if self._vectorset is not None:
        return self._vectorset

    if self.user_vectorset is not None:
        self._vectorset = self.user_vectorset
    else:
        # when it's not provided, we get the default from Predict API
        info = await self.query_information()
        if info.sentence is None or len(info.sentence.vectors) == 0:
            logger.error(
                "Asking for a vectorset but /query didn't return one", extra={"kbid": self.kbid}
            )
            raise SendToPredictError("Predict API didn't return a sentence vectorset")
        # vectors field is enforced by the data model to have at least one key
        self._vectorset = next(iter(info.sentence.vectors.keys()))

    assert self._vectorset is not None
    return self._vectorset
|
|
113
|
+
|
|
114
|
+
async def get_query_vector(self) -> list[float]:
    """Return the vector to use for semantic search.

    A vector stored in ``self.user_vector`` is returned as-is. Otherwise the
    vector is read from the query information's ``sentence.vectors`` under
    the vectorset given by ``get_vectorset()``.

    Raises:
        SendToPredictError: if the query information has no sentence, or the
            sentence has no vector for the resolved vectorset.
    """
    supplied = self.user_vector
    if supplied is not None:
        return supplied

    info = await self.query_information()
    if info.sentence is None:
        logger.error(
            "Asking for a semantic query vector but /query didn't return a sentence",
            extra={"kbid": self.kbid},
        )
        raise SendToPredictError("Predict API didn't return a sentence for semantic search")

    selected = await self.get_vectorset()
    if selected in info.sentence.vectors:
        return info.sentence.vectors[selected]

    # The resolved vectorset is missing from the response: log what was returned.
    logger.error(
        "Predict is not responding with a valid query nucliadb vectorset",
        extra={
            "kbid": self.kbid,
            "vectorset": selected,
            "predict_vectorsets": ",".join(info.sentence.vectors.keys()),
        },
    )
    raise SendToPredictError("Predict API didn't return the requested vectorset")
|
|
140
|
+
|
|
141
|
+
async def get_classification_labels(self) -> knowledgebox_pb2.Labels:
    """Fetch the labelsets for ``self.kbid`` and return them as a protobuf ``Labels``.

    Each labelset returned by ``rpc.labelsets`` is dumped with
    ``model_dump()`` and parsed into the entry of the same name in the
    protobuf message's ``labelset`` map.
    """
    response = await rpc.labelsets(self.kbid)

    # TODO(decoupled-ask): remove this conversion and refactor code to use API models instead of protobuf
    converted = knowledgebox_pb2.Labels()
    for name, definition in response.labelsets.items():
        target = converted.labelset[name]
        ParseDict(definition.model_dump(), target)

    return converted
|
|
150
|
+
|
|
151
|
+
# Generative
|
|
152
|
+
|
|
153
|
+
async def get_visual_llm_enabled(self) -> bool:
    """Return the ``visual_llm`` flag from the query information.

    Raises:
        SendToPredictError: if no query information is available.
    """
    info = await self.query_information()
    if info is not None:
        return info.visual_llm

    raise SendToPredictError("Error while using predict's query endpoint")
|
|
159
|
+
|
|
160
|
+
async def get_max_context_tokens(self, max_tokens: MaxTokens | None) -> int:
    """Return the number of context tokens to use.

    The query information's ``max_context`` acts as the upper limit. When
    ``max_tokens.context`` is set it is returned, provided it does not exceed
    that limit; otherwise the limit itself is returned.

    Raises:
        SendToPredictError: if no query information is available.
        InvalidQueryError: if ``max_tokens.context`` is above the limit.
    """
    info = await self.query_information()
    if info is None:
        raise SendToPredictError("Error while using predict's query endpoint")

    model_max = info.max_context
    # No explicit context budget requested: fall back to the limit.
    if max_tokens is None or max_tokens.context is None:
        return model_max

    if max_tokens.context > model_max:
        raise InvalidQueryError(
            "max_tokens.context",
            f"Max context tokens is higher than the model's limit of {model_max}",
        )
    return max_tokens.context
|
|
174
|
+
|
|
175
|
+
def get_max_answer_tokens(self, max_tokens: MaxTokens | None) -> int | None:
    """Return ``max_tokens.answer``, or ``None`` when no limit was given.

    ``None`` is returned both when ``max_tokens`` itself is ``None`` and when
    its ``answer`` attribute is ``None``.
    """
    if max_tokens is None:
        return None
    answer_limit = max_tokens.answer
    return answer_limit if answer_limit is not None else None
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
async def query_information(
    kbid: str,
    query: str,
    semantic_model: str | None,
    generative_model: str | None = None,
    rephrase: bool = False,
    rephrase_prompt: str | None = None,
    query_image: Image | None = None,
) -> QueryInfo:
    """Build a ``QueryModel`` from the arguments and send it to ``predict.query``.

    ``semantic_model`` is wrapped in a single-element list when truthy;
    otherwise ``semantic_models`` is sent as ``None``.
    """
    # NOTE: When moving /ask to RAO, this will need to change to whatever client/utility is used
    # to call NUA predict (internally or externally in the case of onprem).
    predict = get_predict()

    if semantic_model:
        semantic_models = [semantic_model]
    else:
        semantic_models = None

    request = QueryModel(
        text=query,
        semantic_models=semantic_models,
        generative_model=generative_model,
        rephrase=rephrase,
        rephrase_prompt=rephrase_prompt,
        query_image=query_image,
    )
    response = await predict.query(kbid, request)
    return response
|
|
@@ -19,7 +19,6 @@
|
|
|
19
19
|
|
|
20
20
|
import base64
|
|
21
21
|
from io import BytesIO
|
|
22
|
-
from typing import Optional
|
|
23
22
|
|
|
24
23
|
from nucliadb.common.ids import ParagraphId
|
|
25
24
|
from nucliadb.ingest.fields.file import File
|
|
@@ -29,7 +28,8 @@ from nucliadb_utils.storages.storage import Storage
|
|
|
29
28
|
from nucliadb_utils.utilities import get_storage
|
|
30
29
|
|
|
31
30
|
|
|
32
|
-
|
|
31
|
+
# DEPRECATED(decoupled-ask): remove once old_prompt.py is removed
|
|
32
|
+
async def get_page_image(kbid: str, paragraph_id: ParagraphId, page_number: int) -> Image | None:
|
|
33
33
|
storage = await get_storage(service_name=SERVICE_NAME)
|
|
34
34
|
sf = storage.file_extracted(
|
|
35
35
|
kbid=kbid,
|
|
@@ -48,7 +48,8 @@ async def get_page_image(kbid: str, paragraph_id: ParagraphId, page_number: int)
|
|
|
48
48
|
return image
|
|
49
49
|
|
|
50
50
|
|
|
51
|
-
|
|
51
|
+
# DEPRECATED(decoupled-ask): remove once old_prompt.py is removed
|
|
52
|
+
async def get_paragraph_image(kbid: str, paragraph_id: ParagraphId, reference: str) -> Image | None:
|
|
52
53
|
storage = await get_storage(service_name=SERVICE_NAME)
|
|
53
54
|
sf = storage.file_extracted(
|
|
54
55
|
kbid=kbid,
|
|
@@ -67,7 +68,8 @@ async def get_paragraph_image(kbid: str, paragraph_id: ParagraphId, reference: s
|
|
|
67
68
|
return image
|
|
68
69
|
|
|
69
70
|
|
|
70
|
-
|
|
71
|
+
# DEPRECATED(decoupled-ask): remove once old_prompt.py is removed
|
|
72
|
+
async def get_file_thumbnail_image(file: File) -> Image | None:
|
|
71
73
|
fed = await file.get_file_extracted_data()
|
|
72
74
|
if fed is None or not fed.HasField("file_thumbnail"):
|
|
73
75
|
return None
|