nucliadb 6.9.1.post5192__py3-none-any.whl → 6.10.0.post5705__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- migrations/0023_backfill_pg_catalog.py +2 -2
- migrations/0029_backfill_field_status.py +3 -4
- migrations/0032_remove_old_relations.py +2 -3
- migrations/0038_backfill_catalog_field_labels.py +2 -2
- migrations/0039_backfill_converation_splits_metadata.py +2 -2
- migrations/0041_reindex_conversations.py +137 -0
- migrations/pg/0010_shards_index.py +34 -0
- nucliadb/search/api/v1/resource/utils.py → migrations/pg/0011_catalog_statistics.py +5 -6
- migrations/pg/0012_catalog_statistics_undo.py +26 -0
- nucliadb/backups/create.py +2 -15
- nucliadb/backups/restore.py +4 -15
- nucliadb/backups/tasks.py +4 -1
- nucliadb/common/back_pressure/cache.py +2 -3
- nucliadb/common/back_pressure/materializer.py +7 -13
- nucliadb/common/back_pressure/settings.py +6 -6
- nucliadb/common/back_pressure/utils.py +1 -0
- nucliadb/common/cache.py +9 -9
- nucliadb/common/catalog/interface.py +12 -12
- nucliadb/common/catalog/pg.py +41 -29
- nucliadb/common/catalog/utils.py +3 -3
- nucliadb/common/cluster/manager.py +5 -4
- nucliadb/common/cluster/rebalance.py +483 -114
- nucliadb/common/cluster/rollover.py +25 -9
- nucliadb/common/cluster/settings.py +3 -8
- nucliadb/common/cluster/utils.py +34 -8
- nucliadb/common/context/__init__.py +7 -8
- nucliadb/common/context/fastapi.py +1 -2
- nucliadb/common/datamanagers/__init__.py +2 -4
- nucliadb/common/datamanagers/atomic.py +4 -2
- nucliadb/common/datamanagers/cluster.py +1 -2
- nucliadb/common/datamanagers/fields.py +3 -4
- nucliadb/common/datamanagers/kb.py +6 -6
- nucliadb/common/datamanagers/labels.py +2 -3
- nucliadb/common/datamanagers/resources.py +10 -33
- nucliadb/common/datamanagers/rollover.py +5 -7
- nucliadb/common/datamanagers/search_configurations.py +1 -2
- nucliadb/common/datamanagers/synonyms.py +1 -2
- nucliadb/common/datamanagers/utils.py +4 -4
- nucliadb/common/datamanagers/vectorsets.py +4 -4
- nucliadb/common/external_index_providers/base.py +32 -5
- nucliadb/common/external_index_providers/manager.py +4 -5
- nucliadb/common/filter_expression.py +128 -40
- nucliadb/common/http_clients/processing.py +12 -23
- nucliadb/common/ids.py +6 -4
- nucliadb/common/locking.py +1 -2
- nucliadb/common/maindb/driver.py +9 -8
- nucliadb/common/maindb/local.py +5 -5
- nucliadb/common/maindb/pg.py +9 -8
- nucliadb/common/nidx.py +3 -4
- nucliadb/export_import/datamanager.py +4 -3
- nucliadb/export_import/exporter.py +11 -19
- nucliadb/export_import/importer.py +13 -6
- nucliadb/export_import/tasks.py +2 -0
- nucliadb/export_import/utils.py +6 -18
- nucliadb/health.py +2 -2
- nucliadb/ingest/app.py +8 -8
- nucliadb/ingest/consumer/consumer.py +8 -10
- nucliadb/ingest/consumer/pull.py +3 -8
- nucliadb/ingest/consumer/service.py +3 -3
- nucliadb/ingest/consumer/utils.py +1 -1
- nucliadb/ingest/fields/base.py +28 -49
- nucliadb/ingest/fields/conversation.py +12 -12
- nucliadb/ingest/fields/exceptions.py +1 -2
- nucliadb/ingest/fields/file.py +22 -8
- nucliadb/ingest/fields/link.py +7 -7
- nucliadb/ingest/fields/text.py +2 -3
- nucliadb/ingest/orm/brain_v2.py +78 -64
- nucliadb/ingest/orm/broker_message.py +2 -4
- nucliadb/ingest/orm/entities.py +10 -209
- nucliadb/ingest/orm/index_message.py +4 -4
- nucliadb/ingest/orm/knowledgebox.py +18 -27
- nucliadb/ingest/orm/processor/auditing.py +1 -3
- nucliadb/ingest/orm/processor/data_augmentation.py +1 -2
- nucliadb/ingest/orm/processor/processor.py +27 -27
- nucliadb/ingest/orm/processor/sequence_manager.py +1 -2
- nucliadb/ingest/orm/resource.py +72 -70
- nucliadb/ingest/orm/utils.py +1 -1
- nucliadb/ingest/processing.py +17 -17
- nucliadb/ingest/serialize.py +202 -145
- nucliadb/ingest/service/writer.py +3 -109
- nucliadb/ingest/settings.py +3 -4
- nucliadb/ingest/utils.py +1 -2
- nucliadb/learning_proxy.py +11 -11
- nucliadb/metrics_exporter.py +5 -4
- nucliadb/middleware/__init__.py +82 -1
- nucliadb/migrator/datamanager.py +3 -4
- nucliadb/migrator/migrator.py +1 -2
- nucliadb/migrator/models.py +1 -2
- nucliadb/migrator/settings.py +1 -2
- nucliadb/models/internal/augment.py +614 -0
- nucliadb/models/internal/processing.py +19 -19
- nucliadb/openapi.py +2 -2
- nucliadb/purge/__init__.py +3 -8
- nucliadb/purge/orphan_shards.py +1 -2
- nucliadb/reader/__init__.py +5 -0
- nucliadb/reader/api/models.py +6 -13
- nucliadb/reader/api/v1/download.py +59 -38
- nucliadb/reader/api/v1/export_import.py +4 -4
- nucliadb/reader/api/v1/learning_config.py +24 -4
- nucliadb/reader/api/v1/resource.py +61 -9
- nucliadb/reader/api/v1/services.py +18 -14
- nucliadb/reader/app.py +3 -1
- nucliadb/reader/reader/notifications.py +1 -2
- nucliadb/search/api/v1/__init__.py +2 -0
- nucliadb/search/api/v1/ask.py +3 -4
- nucliadb/search/api/v1/augment.py +585 -0
- nucliadb/search/api/v1/catalog.py +11 -15
- nucliadb/search/api/v1/find.py +16 -22
- nucliadb/search/api/v1/hydrate.py +25 -25
- nucliadb/search/api/v1/knowledgebox.py +1 -2
- nucliadb/search/api/v1/predict_proxy.py +1 -2
- nucliadb/search/api/v1/resource/ask.py +7 -7
- nucliadb/search/api/v1/resource/ingestion_agents.py +5 -6
- nucliadb/search/api/v1/resource/search.py +9 -11
- nucliadb/search/api/v1/retrieve.py +130 -0
- nucliadb/search/api/v1/search.py +28 -32
- nucliadb/search/api/v1/suggest.py +11 -14
- nucliadb/search/api/v1/summarize.py +1 -2
- nucliadb/search/api/v1/utils.py +2 -2
- nucliadb/search/app.py +3 -2
- nucliadb/search/augmentor/__init__.py +21 -0
- nucliadb/search/augmentor/augmentor.py +232 -0
- nucliadb/search/augmentor/fields.py +704 -0
- nucliadb/search/augmentor/metrics.py +24 -0
- nucliadb/search/augmentor/paragraphs.py +334 -0
- nucliadb/search/augmentor/resources.py +238 -0
- nucliadb/search/augmentor/utils.py +33 -0
- nucliadb/search/lifecycle.py +3 -1
- nucliadb/search/predict.py +24 -17
- nucliadb/search/predict_models.py +8 -9
- nucliadb/search/requesters/utils.py +11 -10
- nucliadb/search/search/cache.py +19 -23
- nucliadb/search/search/chat/ask.py +88 -59
- nucliadb/search/search/chat/exceptions.py +3 -5
- nucliadb/search/search/chat/fetcher.py +201 -0
- nucliadb/search/search/chat/images.py +6 -4
- nucliadb/search/search/chat/old_prompt.py +1375 -0
- nucliadb/search/search/chat/parser.py +510 -0
- nucliadb/search/search/chat/prompt.py +563 -615
- nucliadb/search/search/chat/query.py +449 -36
- nucliadb/search/search/chat/rpc.py +85 -0
- nucliadb/search/search/fetch.py +3 -4
- nucliadb/search/search/filters.py +8 -11
- nucliadb/search/search/find.py +33 -31
- nucliadb/search/search/find_merge.py +124 -331
- nucliadb/search/search/graph_strategy.py +14 -12
- nucliadb/search/search/hydrator/__init__.py +3 -152
- nucliadb/search/search/hydrator/fields.py +92 -50
- nucliadb/search/search/hydrator/images.py +7 -7
- nucliadb/search/search/hydrator/paragraphs.py +42 -26
- nucliadb/search/search/hydrator/resources.py +20 -16
- nucliadb/search/search/ingestion_agents.py +5 -5
- nucliadb/search/search/merge.py +90 -94
- nucliadb/search/search/metrics.py +10 -9
- nucliadb/search/search/paragraphs.py +7 -9
- nucliadb/search/search/predict_proxy.py +13 -9
- nucliadb/search/search/query.py +14 -86
- nucliadb/search/search/query_parser/fetcher.py +51 -82
- nucliadb/search/search/query_parser/models.py +19 -20
- nucliadb/search/search/query_parser/old_filters.py +20 -19
- nucliadb/search/search/query_parser/parsers/ask.py +4 -5
- nucliadb/search/search/query_parser/parsers/catalog.py +5 -6
- nucliadb/search/search/query_parser/parsers/common.py +5 -6
- nucliadb/search/search/query_parser/parsers/find.py +6 -26
- nucliadb/search/search/query_parser/parsers/graph.py +13 -23
- nucliadb/search/search/query_parser/parsers/retrieve.py +207 -0
- nucliadb/search/search/query_parser/parsers/search.py +15 -53
- nucliadb/search/search/query_parser/parsers/unit_retrieval.py +8 -29
- nucliadb/search/search/rank_fusion.py +18 -13
- nucliadb/search/search/rerankers.py +5 -6
- nucliadb/search/search/retrieval.py +300 -0
- nucliadb/search/search/summarize.py +5 -6
- nucliadb/search/search/utils.py +3 -4
- nucliadb/search/settings.py +1 -2
- nucliadb/standalone/api_router.py +1 -1
- nucliadb/standalone/app.py +4 -3
- nucliadb/standalone/auth.py +5 -6
- nucliadb/standalone/lifecycle.py +2 -2
- nucliadb/standalone/run.py +2 -4
- nucliadb/standalone/settings.py +5 -6
- nucliadb/standalone/versions.py +3 -4
- nucliadb/tasks/consumer.py +13 -8
- nucliadb/tasks/models.py +2 -1
- nucliadb/tasks/producer.py +3 -3
- nucliadb/tasks/retries.py +8 -7
- nucliadb/train/api/utils.py +1 -3
- nucliadb/train/api/v1/shards.py +1 -2
- nucliadb/train/api/v1/trainset.py +1 -2
- nucliadb/train/app.py +1 -1
- nucliadb/train/generator.py +4 -4
- nucliadb/train/generators/field_classifier.py +2 -2
- nucliadb/train/generators/field_streaming.py +6 -6
- nucliadb/train/generators/image_classifier.py +2 -2
- nucliadb/train/generators/paragraph_classifier.py +2 -2
- nucliadb/train/generators/paragraph_streaming.py +2 -2
- nucliadb/train/generators/question_answer_streaming.py +2 -2
- nucliadb/train/generators/sentence_classifier.py +2 -2
- nucliadb/train/generators/token_classifier.py +3 -2
- nucliadb/train/generators/utils.py +6 -5
- nucliadb/train/nodes.py +3 -3
- nucliadb/train/resource.py +6 -8
- nucliadb/train/settings.py +3 -4
- nucliadb/train/types.py +11 -11
- nucliadb/train/upload.py +3 -2
- nucliadb/train/uploader.py +1 -2
- nucliadb/train/utils.py +1 -2
- nucliadb/writer/api/v1/export_import.py +4 -1
- nucliadb/writer/api/v1/field.py +7 -11
- nucliadb/writer/api/v1/knowledgebox.py +3 -4
- nucliadb/writer/api/v1/resource.py +9 -20
- nucliadb/writer/api/v1/services.py +10 -132
- nucliadb/writer/api/v1/upload.py +73 -72
- nucliadb/writer/app.py +8 -2
- nucliadb/writer/resource/basic.py +12 -15
- nucliadb/writer/resource/field.py +7 -5
- nucliadb/writer/resource/origin.py +7 -0
- nucliadb/writer/settings.py +2 -3
- nucliadb/writer/tus/__init__.py +2 -3
- nucliadb/writer/tus/azure.py +1 -3
- nucliadb/writer/tus/dm.py +3 -3
- nucliadb/writer/tus/exceptions.py +3 -4
- nucliadb/writer/tus/gcs.py +5 -6
- nucliadb/writer/tus/s3.py +2 -3
- nucliadb/writer/tus/storage.py +3 -3
- {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/METADATA +9 -10
- nucliadb-6.10.0.post5705.dist-info/RECORD +410 -0
- nucliadb/common/datamanagers/entities.py +0 -139
- nucliadb-6.9.1.post5192.dist-info/RECORD +0 -392
- {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/WHEEL +0 -0
- {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/top_level.txt +0 -0
nucliadb/search/predict.py
CHANGED
|
@@ -22,9 +22,9 @@ import json
|
|
|
22
22
|
import logging
|
|
23
23
|
import os
|
|
24
24
|
import random
|
|
25
|
+
from collections.abc import AsyncGenerator
|
|
25
26
|
from dataclasses import dataclass
|
|
26
27
|
from enum import Enum
|
|
27
|
-
from typing import AsyncGenerator, Optional
|
|
28
28
|
from unittest.mock import AsyncMock, Mock
|
|
29
29
|
|
|
30
30
|
import aiohttp
|
|
@@ -144,7 +144,7 @@ class AnswerStatusCode(str, Enum):
|
|
|
144
144
|
@dataclass
|
|
145
145
|
class RephraseResponse:
|
|
146
146
|
rephrased_query: str
|
|
147
|
-
use_chat_history:
|
|
147
|
+
use_chat_history: bool | None
|
|
148
148
|
|
|
149
149
|
|
|
150
150
|
async def start_predict_engine():
|
|
@@ -176,18 +176,18 @@ def convert_relations(data: dict[str, list[dict[str, str]]]) -> list[RelationNod
|
|
|
176
176
|
class PredictEngine:
|
|
177
177
|
def __init__(
|
|
178
178
|
self,
|
|
179
|
-
cluster_url:
|
|
180
|
-
public_url:
|
|
181
|
-
nuclia_service_account:
|
|
182
|
-
zone:
|
|
179
|
+
cluster_url: str | None = None,
|
|
180
|
+
public_url: str | None = None,
|
|
181
|
+
nuclia_service_account: str | None = None,
|
|
182
|
+
zone: str | None = None,
|
|
183
183
|
onprem: bool = False,
|
|
184
184
|
local_predict: bool = False,
|
|
185
|
-
local_predict_headers:
|
|
185
|
+
local_predict_headers: dict[str, str] | None = None,
|
|
186
186
|
):
|
|
187
187
|
self.nuclia_service_account = nuclia_service_account
|
|
188
188
|
self.cluster_url = cluster_url
|
|
189
189
|
if public_url is not None:
|
|
190
|
-
self.public_url:
|
|
190
|
+
self.public_url: str | None = public_url.format(zone=zone)
|
|
191
191
|
else:
|
|
192
192
|
self.public_url = None
|
|
193
193
|
self.zone = zone
|
|
@@ -294,7 +294,7 @@ class PredictEngine:
|
|
|
294
294
|
|
|
295
295
|
@predict_observer.wrap({"type": "chat_ndjson"})
|
|
296
296
|
async def chat_query_ndjson(
|
|
297
|
-
self, kbid: str, item: ChatModel, extra_headers:
|
|
297
|
+
self, kbid: str, item: ChatModel, extra_headers: dict[str, str] | None = None
|
|
298
298
|
) -> tuple[str, str, AsyncGenerator[GenerativeChunk, None]]:
|
|
299
299
|
"""
|
|
300
300
|
Chat query using the new stream format
|
|
@@ -383,7 +383,7 @@ class PredictEngine:
|
|
|
383
383
|
|
|
384
384
|
@predict_observer.wrap({"type": "summarize"})
|
|
385
385
|
async def summarize(
|
|
386
|
-
self, kbid: str, item: SummarizeModel, extra_headers:
|
|
386
|
+
self, kbid: str, item: SummarizeModel, extra_headers: dict[str, str] | None = None
|
|
387
387
|
) -> SummarizedResponse:
|
|
388
388
|
try:
|
|
389
389
|
self.check_nua_key_is_configured_for_onprem()
|
|
@@ -481,7 +481,7 @@ class DummyPredictEngine(PredictEngine):
|
|
|
481
481
|
return RephraseResponse(rephrased_query=DUMMY_REPHRASE_QUERY, use_chat_history=None)
|
|
482
482
|
|
|
483
483
|
async def chat_query_ndjson(
|
|
484
|
-
self, kbid: str, item: ChatModel, extra_headers:
|
|
484
|
+
self, kbid: str, item: ChatModel, extra_headers: dict[str, str] | None = None
|
|
485
485
|
) -> tuple[str, str, AsyncGenerator[GenerativeChunk, None]]:
|
|
486
486
|
self.calls.append(("chat_query_ndjson", item))
|
|
487
487
|
|
|
@@ -524,10 +524,17 @@ class DummyPredictEngine(PredictEngine):
|
|
|
524
524
|
timings[vectorset_id] = 0.010
|
|
525
525
|
|
|
526
526
|
# and fake data with the passed one too
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
527
|
+
if item.semantic_models is not None:
|
|
528
|
+
for model in item.semantic_models:
|
|
529
|
+
semantic_thresholds[model] = self.default_semantic_threshold
|
|
530
|
+
vectors[model] = base_vector
|
|
531
|
+
timings[model] = 0.0
|
|
532
|
+
|
|
533
|
+
if len(vectors) == 0:
|
|
534
|
+
model = "<PREDICT-DEFAULT-SEMANTIC-MODEL>"
|
|
535
|
+
semantic_thresholds[model] = self.default_semantic_threshold
|
|
536
|
+
vectors[model] = base_vector
|
|
537
|
+
timings[model] = 0.0
|
|
531
538
|
|
|
532
539
|
return QueryInfo(
|
|
533
540
|
language="en",
|
|
@@ -540,7 +547,7 @@ class DummyPredictEngine(PredictEngine):
|
|
|
540
547
|
vectors=vectors,
|
|
541
548
|
timings=timings,
|
|
542
549
|
),
|
|
543
|
-
query=
|
|
550
|
+
query=item.text or "<PREDICT-QUERY>",
|
|
544
551
|
rephrased_query="<REPHRASED-QUERY>" if item.rephrase or item.query_image else None,
|
|
545
552
|
)
|
|
546
553
|
|
|
@@ -553,7 +560,7 @@ class DummyPredictEngine(PredictEngine):
|
|
|
553
560
|
return DUMMY_RELATION_NODE
|
|
554
561
|
|
|
555
562
|
async def summarize(
|
|
556
|
-
self, kbid: str, item: SummarizeModel, extra_headers:
|
|
563
|
+
self, kbid: str, item: SummarizeModel, extra_headers: dict[str, str] | None = None
|
|
557
564
|
) -> SummarizedResponse:
|
|
558
565
|
self.calls.append(("summarize", (kbid, item)))
|
|
559
566
|
response = SummarizedResponse(
|
|
@@ -19,7 +19,6 @@
|
|
|
19
19
|
|
|
20
20
|
from base64 import b64decode, b64encode
|
|
21
21
|
from enum import Enum
|
|
22
|
-
from typing import Optional
|
|
23
22
|
|
|
24
23
|
from google.protobuf.message import DecodeError, Message
|
|
25
24
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
@@ -77,7 +76,7 @@ class RunAgentsRequest(BaseModel):
|
|
|
77
76
|
default_factory=list,
|
|
78
77
|
title="An optional list of Data Augmentation Agent IDs to run. If empty, all configured agents that match the filters are run.",
|
|
79
78
|
)
|
|
80
|
-
filters:
|
|
79
|
+
filters: list[NameOperationFilter] | None = Field(
|
|
81
80
|
default=None,
|
|
82
81
|
title="Filters to select which Data Augmentation Agents are applied to the text. If empty, all configured agents for the Knowledge Box are applied.",
|
|
83
82
|
)
|
|
@@ -93,7 +92,7 @@ class AppliedDataAugmentation(BaseModel):
|
|
|
93
92
|
# Since we have protos as fields, we need to enable arbitrary_types_allowed
|
|
94
93
|
arbitrary_types_allowed=True,
|
|
95
94
|
)
|
|
96
|
-
qas:
|
|
95
|
+
qas: QuestionAnswers | None = Field(
|
|
97
96
|
default=None,
|
|
98
97
|
description="Question and answers generated by the Question Answers agent",
|
|
99
98
|
)
|
|
@@ -107,7 +106,7 @@ class AppliedDataAugmentation(BaseModel):
|
|
|
107
106
|
)
|
|
108
107
|
|
|
109
108
|
@field_validator("qas", mode="before")
|
|
110
|
-
def validate_qas(cls, qas:
|
|
109
|
+
def validate_qas(cls, qas: str | None) -> QuestionAnswers | None:
|
|
111
110
|
if qas is None:
|
|
112
111
|
return None
|
|
113
112
|
try:
|
|
@@ -171,8 +170,8 @@ class QueryModel(BaseModel):
|
|
|
171
170
|
Model to represent a query request
|
|
172
171
|
"""
|
|
173
172
|
|
|
174
|
-
text:
|
|
175
|
-
query_image:
|
|
173
|
+
text: str | None = Field(default=None, description="The query text to be processed")
|
|
174
|
+
query_image: Image | None = Field(
|
|
176
175
|
default=None,
|
|
177
176
|
description="Image to be considered as part of the query. Even if the `rephrase` parameter is set to `false`, the rephrasing process will occur, combining the provided text with the image's visual features in the rephrased query.",
|
|
178
177
|
)
|
|
@@ -180,7 +179,7 @@ class QueryModel(BaseModel):
|
|
|
180
179
|
default=False,
|
|
181
180
|
description="If true, the model will rephrase the input text before processing",
|
|
182
181
|
)
|
|
183
|
-
rephrase_prompt:
|
|
182
|
+
rephrase_prompt: str | None = Field(
|
|
184
183
|
default=None,
|
|
185
184
|
description="Custom prompt for rephrasing the input text",
|
|
186
185
|
examples=[
|
|
@@ -192,11 +191,11 @@ QUESTION: {question}
|
|
|
192
191
|
Please return ONLY the question without any explanation.""",
|
|
193
192
|
],
|
|
194
193
|
)
|
|
195
|
-
generative_model:
|
|
194
|
+
generative_model: str | None = Field(
|
|
196
195
|
default=None,
|
|
197
196
|
description="The generative model to use for rephrasing",
|
|
198
197
|
)
|
|
199
|
-
semantic_models:
|
|
198
|
+
semantic_models: list[str] | None = Field(
|
|
200
199
|
default=None,
|
|
201
200
|
description="Semantic models to compute the sentence vector for, if not provided, it will only compute the sentence vector for default semantic model in the Knowledge box's configuration.",
|
|
202
201
|
)
|
|
@@ -19,8 +19,9 @@
|
|
|
19
19
|
|
|
20
20
|
import asyncio
|
|
21
21
|
import json
|
|
22
|
+
from collections.abc import Sequence
|
|
22
23
|
from enum import Enum, auto
|
|
23
|
-
from typing import
|
|
24
|
+
from typing import TypeVar, overload
|
|
24
25
|
|
|
25
26
|
from fastapi import HTTPException
|
|
26
27
|
from google.protobuf.json_format import MessageToDict
|
|
@@ -60,7 +61,7 @@ METHODS = {
|
|
|
60
61
|
Method.GRAPH: graph_search_shard,
|
|
61
62
|
}
|
|
62
63
|
|
|
63
|
-
REQUEST_TYPE =
|
|
64
|
+
REQUEST_TYPE = SuggestRequest | SearchRequest | GraphSearchRequest
|
|
64
65
|
|
|
65
66
|
T = TypeVar(
|
|
66
67
|
"T",
|
|
@@ -75,7 +76,7 @@ async def nidx_query(
|
|
|
75
76
|
kbid: str,
|
|
76
77
|
method: Method,
|
|
77
78
|
pb_query: SuggestRequest,
|
|
78
|
-
timeout:
|
|
79
|
+
timeout: float | None = None,
|
|
79
80
|
) -> tuple[list[SuggestResponse], list[str]]: ...
|
|
80
81
|
|
|
81
82
|
|
|
@@ -84,7 +85,7 @@ async def nidx_query(
|
|
|
84
85
|
kbid: str,
|
|
85
86
|
method: Method,
|
|
86
87
|
pb_query: SearchRequest,
|
|
87
|
-
timeout:
|
|
88
|
+
timeout: float | None = None,
|
|
88
89
|
) -> tuple[list[SearchResponse], list[str]]: ...
|
|
89
90
|
|
|
90
91
|
|
|
@@ -93,7 +94,7 @@ async def nidx_query(
|
|
|
93
94
|
kbid: str,
|
|
94
95
|
method: Method,
|
|
95
96
|
pb_query: GraphSearchRequest,
|
|
96
|
-
timeout:
|
|
97
|
+
timeout: float | None = None,
|
|
97
98
|
) -> tuple[list[GraphSearchResponse], list[str]]: ...
|
|
98
99
|
|
|
99
100
|
|
|
@@ -101,8 +102,8 @@ async def nidx_query(
|
|
|
101
102
|
kbid: str,
|
|
102
103
|
method: Method,
|
|
103
104
|
pb_query: REQUEST_TYPE,
|
|
104
|
-
timeout:
|
|
105
|
-
) -> tuple[Sequence[
|
|
105
|
+
timeout: float | None = None,
|
|
106
|
+
) -> tuple[Sequence[T | BaseException], list[str]]:
|
|
106
107
|
timeout = timeout or settings.search_timeout
|
|
107
108
|
shard_manager = get_shard_manager()
|
|
108
109
|
try:
|
|
@@ -133,7 +134,7 @@ async def nidx_query(
|
|
|
133
134
|
)
|
|
134
135
|
|
|
135
136
|
try:
|
|
136
|
-
results: list[
|
|
137
|
+
results: list[T | BaseException] = await asyncio.wait_for(
|
|
137
138
|
asyncio.gather(*ops, return_exceptions=True),
|
|
138
139
|
timeout=timeout,
|
|
139
140
|
)
|
|
@@ -159,13 +160,13 @@ async def nidx_query(
|
|
|
159
160
|
return results, queried_shards
|
|
160
161
|
|
|
161
162
|
|
|
162
|
-
def validate_nidx_query_results(results: list[
|
|
163
|
+
def validate_nidx_query_results(results: list[T | BaseException]) -> HTTPException | None:
|
|
163
164
|
"""
|
|
164
165
|
Validate the results of a nidx query and return an exception if any error is found
|
|
165
166
|
|
|
166
167
|
Handling of exception is responsibility of caller.
|
|
167
168
|
"""
|
|
168
|
-
if
|
|
169
|
+
if len(results) == 0:
|
|
169
170
|
return HTTPException(status_code=500, detail=f"Error while executing shard queries. No results.")
|
|
170
171
|
|
|
171
172
|
for result in results:
|
nucliadb/search/search/cache.py
CHANGED
|
@@ -19,7 +19,6 @@
|
|
|
19
19
|
|
|
20
20
|
import contextlib
|
|
21
21
|
import logging
|
|
22
|
-
from typing import Optional
|
|
23
22
|
|
|
24
23
|
from nucliadb.common.cache import (
|
|
25
24
|
extracted_text_cache,
|
|
@@ -39,26 +38,35 @@ from nucliadb_utils.utilities import get_storage
|
|
|
39
38
|
logger = logging.getLogger(__name__)
|
|
40
39
|
|
|
41
40
|
|
|
42
|
-
async def get_resource(kbid: str, uuid: str) ->
|
|
41
|
+
async def get_resource(kbid: str, uuid: str) -> ResourceORM | None:
|
|
43
42
|
"""
|
|
44
43
|
Will try to get the resource from the cache, if it's not there it will fetch it from the ORM and cache it.
|
|
45
44
|
"""
|
|
46
45
|
resource_cache = get_resource_cache()
|
|
47
46
|
if resource_cache is None:
|
|
48
47
|
logger.warning("Resource cache not set")
|
|
49
|
-
|
|
48
|
+
async with get_driver().ro_transaction() as txn:
|
|
49
|
+
storage = await get_storage(service_name=SERVICE_NAME)
|
|
50
|
+
kb = KnowledgeBoxORM(txn, storage, kbid)
|
|
51
|
+
return await kb.get(uuid)
|
|
50
52
|
|
|
51
53
|
return await resource_cache.get(kbid, uuid)
|
|
52
54
|
|
|
53
55
|
|
|
54
|
-
async def
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
return
|
|
56
|
+
async def get_field(kbid: str, field_id: FieldId) -> Field | None:
|
|
57
|
+
rid = field_id.rid
|
|
58
|
+
orm_resource = await get_resource(kbid, rid)
|
|
59
|
+
if orm_resource is None:
|
|
60
|
+
return None
|
|
61
|
+
field_obj = await orm_resource.get_field(
|
|
62
|
+
key=field_id.key,
|
|
63
|
+
type=field_id.pb_type,
|
|
64
|
+
load=False,
|
|
65
|
+
)
|
|
66
|
+
return field_obj
|
|
59
67
|
|
|
60
68
|
|
|
61
|
-
async def get_field_extracted_text(field: Field) ->
|
|
69
|
+
async def get_field_extracted_text(field: Field) -> ExtractedText | None:
|
|
62
70
|
if field.extracted_text is not None:
|
|
63
71
|
return field.extracted_text
|
|
64
72
|
|
|
@@ -72,19 +80,6 @@ async def get_field_extracted_text(field: Field) -> Optional[ExtractedText]:
|
|
|
72
80
|
return extracted_text
|
|
73
81
|
|
|
74
82
|
|
|
75
|
-
async def get_extracted_text_from_field_id(kbid: str, field: FieldId) -> Optional[ExtractedText]:
|
|
76
|
-
rid = field.rid
|
|
77
|
-
orm_resource = await get_resource(kbid, rid)
|
|
78
|
-
if orm_resource is None:
|
|
79
|
-
return None
|
|
80
|
-
field_obj = await orm_resource.get_field(
|
|
81
|
-
key=field.key,
|
|
82
|
-
type=field.pb_type,
|
|
83
|
-
load=False,
|
|
84
|
-
)
|
|
85
|
-
return await get_field_extracted_text(field_obj)
|
|
86
|
-
|
|
87
|
-
|
|
88
83
|
@contextlib.contextmanager
|
|
89
84
|
def request_caches():
|
|
90
85
|
"""
|
|
@@ -96,7 +91,8 @@ def request_caches():
|
|
|
96
91
|
Makes sure to clean the caches at the end of the context manager.
|
|
97
92
|
>>> with request_caches():
|
|
98
93
|
... resource = await get_resource(kbid, uuid)
|
|
99
|
-
...
|
|
94
|
+
... field = await get_field(kbid, field_id)
|
|
95
|
+
... extracted_text = await get_field_extracted_text(field)
|
|
100
96
|
"""
|
|
101
97
|
|
|
102
98
|
# This cache size is an arbitrary number, once we have a metric in place and
|