nucliadb 6.7.2.post4874__py3-none-any.whl → 6.10.0.post5705__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- migrations/0023_backfill_pg_catalog.py +8 -4
- migrations/0028_extracted_vectors_reference.py +1 -1
- migrations/0029_backfill_field_status.py +3 -4
- migrations/0032_remove_old_relations.py +2 -3
- migrations/0038_backfill_catalog_field_labels.py +8 -4
- migrations/0039_backfill_converation_splits_metadata.py +106 -0
- migrations/0040_migrate_search_configurations.py +79 -0
- migrations/0041_reindex_conversations.py +137 -0
- migrations/pg/0010_shards_index.py +34 -0
- nucliadb/search/api/v1/resource/utils.py → migrations/pg/0011_catalog_statistics.py +5 -6
- migrations/pg/0012_catalog_statistics_undo.py +26 -0
- nucliadb/backups/create.py +2 -15
- nucliadb/backups/restore.py +4 -15
- nucliadb/backups/tasks.py +4 -1
- nucliadb/common/back_pressure/cache.py +2 -3
- nucliadb/common/back_pressure/materializer.py +7 -13
- nucliadb/common/back_pressure/settings.py +6 -6
- nucliadb/common/back_pressure/utils.py +1 -0
- nucliadb/common/cache.py +9 -9
- nucliadb/common/catalog/__init__.py +79 -0
- nucliadb/common/catalog/dummy.py +36 -0
- nucliadb/common/catalog/interface.py +85 -0
- nucliadb/{search/search/pgcatalog.py → common/catalog/pg.py} +330 -232
- nucliadb/common/catalog/utils.py +56 -0
- nucliadb/common/cluster/manager.py +8 -23
- nucliadb/common/cluster/rebalance.py +484 -112
- nucliadb/common/cluster/rollover.py +36 -9
- nucliadb/common/cluster/settings.py +4 -9
- nucliadb/common/cluster/utils.py +34 -8
- nucliadb/common/context/__init__.py +7 -8
- nucliadb/common/context/fastapi.py +1 -2
- nucliadb/common/datamanagers/__init__.py +2 -4
- nucliadb/common/datamanagers/atomic.py +9 -2
- nucliadb/common/datamanagers/cluster.py +1 -2
- nucliadb/common/datamanagers/fields.py +3 -4
- nucliadb/common/datamanagers/kb.py +6 -6
- nucliadb/common/datamanagers/labels.py +2 -3
- nucliadb/common/datamanagers/resources.py +10 -33
- nucliadb/common/datamanagers/rollover.py +5 -7
- nucliadb/common/datamanagers/search_configurations.py +1 -2
- nucliadb/common/datamanagers/synonyms.py +1 -2
- nucliadb/common/datamanagers/utils.py +4 -4
- nucliadb/common/datamanagers/vectorsets.py +4 -4
- nucliadb/common/external_index_providers/base.py +32 -5
- nucliadb/common/external_index_providers/manager.py +5 -34
- nucliadb/common/external_index_providers/settings.py +1 -27
- nucliadb/common/filter_expression.py +129 -41
- nucliadb/common/http_clients/exceptions.py +8 -0
- nucliadb/common/http_clients/processing.py +16 -23
- nucliadb/common/http_clients/utils.py +3 -0
- nucliadb/common/ids.py +82 -58
- nucliadb/common/locking.py +1 -2
- nucliadb/common/maindb/driver.py +9 -8
- nucliadb/common/maindb/local.py +5 -5
- nucliadb/common/maindb/pg.py +9 -8
- nucliadb/common/nidx.py +22 -5
- nucliadb/common/vector_index_config.py +1 -1
- nucliadb/export_import/datamanager.py +4 -3
- nucliadb/export_import/exporter.py +11 -19
- nucliadb/export_import/importer.py +13 -6
- nucliadb/export_import/tasks.py +2 -0
- nucliadb/export_import/utils.py +6 -18
- nucliadb/health.py +2 -2
- nucliadb/ingest/app.py +8 -8
- nucliadb/ingest/consumer/consumer.py +8 -10
- nucliadb/ingest/consumer/pull.py +10 -8
- nucliadb/ingest/consumer/service.py +5 -30
- nucliadb/ingest/consumer/shard_creator.py +16 -5
- nucliadb/ingest/consumer/utils.py +1 -1
- nucliadb/ingest/fields/base.py +37 -49
- nucliadb/ingest/fields/conversation.py +55 -9
- nucliadb/ingest/fields/exceptions.py +1 -2
- nucliadb/ingest/fields/file.py +22 -8
- nucliadb/ingest/fields/link.py +7 -7
- nucliadb/ingest/fields/text.py +2 -3
- nucliadb/ingest/orm/brain_v2.py +89 -57
- nucliadb/ingest/orm/broker_message.py +2 -4
- nucliadb/ingest/orm/entities.py +10 -209
- nucliadb/ingest/orm/index_message.py +128 -113
- nucliadb/ingest/orm/knowledgebox.py +91 -59
- nucliadb/ingest/orm/processor/auditing.py +1 -3
- nucliadb/ingest/orm/processor/data_augmentation.py +1 -2
- nucliadb/ingest/orm/processor/processor.py +98 -153
- nucliadb/ingest/orm/processor/sequence_manager.py +1 -2
- nucliadb/ingest/orm/resource.py +82 -71
- nucliadb/ingest/orm/utils.py +1 -1
- nucliadb/ingest/partitions.py +12 -1
- nucliadb/ingest/processing.py +17 -17
- nucliadb/ingest/serialize.py +202 -145
- nucliadb/ingest/service/writer.py +15 -114
- nucliadb/ingest/settings.py +36 -15
- nucliadb/ingest/utils.py +1 -2
- nucliadb/learning_proxy.py +23 -26
- nucliadb/metrics_exporter.py +20 -6
- nucliadb/middleware/__init__.py +82 -1
- nucliadb/migrator/datamanager.py +4 -11
- nucliadb/migrator/migrator.py +1 -2
- nucliadb/migrator/models.py +1 -2
- nucliadb/migrator/settings.py +1 -2
- nucliadb/models/internal/augment.py +614 -0
- nucliadb/models/internal/processing.py +19 -19
- nucliadb/openapi.py +2 -2
- nucliadb/purge/__init__.py +3 -8
- nucliadb/purge/orphan_shards.py +1 -2
- nucliadb/reader/__init__.py +5 -0
- nucliadb/reader/api/models.py +6 -13
- nucliadb/reader/api/v1/download.py +59 -38
- nucliadb/reader/api/v1/export_import.py +4 -4
- nucliadb/reader/api/v1/knowledgebox.py +37 -9
- nucliadb/reader/api/v1/learning_config.py +33 -14
- nucliadb/reader/api/v1/resource.py +61 -9
- nucliadb/reader/api/v1/services.py +18 -14
- nucliadb/reader/app.py +3 -1
- nucliadb/reader/reader/notifications.py +1 -2
- nucliadb/search/api/v1/__init__.py +3 -0
- nucliadb/search/api/v1/ask.py +3 -4
- nucliadb/search/api/v1/augment.py +585 -0
- nucliadb/search/api/v1/catalog.py +15 -19
- nucliadb/search/api/v1/find.py +16 -22
- nucliadb/search/api/v1/hydrate.py +328 -0
- nucliadb/search/api/v1/knowledgebox.py +1 -2
- nucliadb/search/api/v1/predict_proxy.py +1 -2
- nucliadb/search/api/v1/resource/ask.py +28 -8
- nucliadb/search/api/v1/resource/ingestion_agents.py +5 -6
- nucliadb/search/api/v1/resource/search.py +9 -11
- nucliadb/search/api/v1/retrieve.py +130 -0
- nucliadb/search/api/v1/search.py +28 -32
- nucliadb/search/api/v1/suggest.py +11 -14
- nucliadb/search/api/v1/summarize.py +1 -2
- nucliadb/search/api/v1/utils.py +2 -2
- nucliadb/search/app.py +3 -2
- nucliadb/search/augmentor/__init__.py +21 -0
- nucliadb/search/augmentor/augmentor.py +232 -0
- nucliadb/search/augmentor/fields.py +704 -0
- nucliadb/search/augmentor/metrics.py +24 -0
- nucliadb/search/augmentor/paragraphs.py +334 -0
- nucliadb/search/augmentor/resources.py +238 -0
- nucliadb/search/augmentor/utils.py +33 -0
- nucliadb/search/lifecycle.py +3 -1
- nucliadb/search/predict.py +33 -19
- nucliadb/search/predict_models.py +8 -9
- nucliadb/search/requesters/utils.py +11 -10
- nucliadb/search/search/cache.py +19 -42
- nucliadb/search/search/chat/ask.py +131 -59
- nucliadb/search/search/chat/exceptions.py +3 -5
- nucliadb/search/search/chat/fetcher.py +201 -0
- nucliadb/search/search/chat/images.py +6 -4
- nucliadb/search/search/chat/old_prompt.py +1375 -0
- nucliadb/search/search/chat/parser.py +510 -0
- nucliadb/search/search/chat/prompt.py +563 -615
- nucliadb/search/search/chat/query.py +453 -32
- nucliadb/search/search/chat/rpc.py +85 -0
- nucliadb/search/search/fetch.py +3 -4
- nucliadb/search/search/filters.py +8 -11
- nucliadb/search/search/find.py +33 -31
- nucliadb/search/search/find_merge.py +124 -331
- nucliadb/search/search/graph_strategy.py +14 -12
- nucliadb/search/search/hydrator/__init__.py +49 -0
- nucliadb/search/search/hydrator/fields.py +217 -0
- nucliadb/search/search/hydrator/images.py +130 -0
- nucliadb/search/search/hydrator/paragraphs.py +323 -0
- nucliadb/search/search/hydrator/resources.py +60 -0
- nucliadb/search/search/ingestion_agents.py +5 -5
- nucliadb/search/search/merge.py +90 -94
- nucliadb/search/search/metrics.py +24 -7
- nucliadb/search/search/paragraphs.py +7 -9
- nucliadb/search/search/predict_proxy.py +44 -18
- nucliadb/search/search/query.py +14 -86
- nucliadb/search/search/query_parser/fetcher.py +51 -82
- nucliadb/search/search/query_parser/models.py +19 -48
- nucliadb/search/search/query_parser/old_filters.py +20 -19
- nucliadb/search/search/query_parser/parsers/ask.py +5 -6
- nucliadb/search/search/query_parser/parsers/catalog.py +7 -11
- nucliadb/search/search/query_parser/parsers/common.py +21 -13
- nucliadb/search/search/query_parser/parsers/find.py +6 -29
- nucliadb/search/search/query_parser/parsers/graph.py +18 -28
- nucliadb/search/search/query_parser/parsers/retrieve.py +207 -0
- nucliadb/search/search/query_parser/parsers/search.py +15 -56
- nucliadb/search/search/query_parser/parsers/unit_retrieval.py +8 -29
- nucliadb/search/search/rank_fusion.py +18 -13
- nucliadb/search/search/rerankers.py +6 -7
- nucliadb/search/search/retrieval.py +300 -0
- nucliadb/search/search/summarize.py +5 -6
- nucliadb/search/search/utils.py +3 -4
- nucliadb/search/settings.py +1 -2
- nucliadb/standalone/api_router.py +1 -1
- nucliadb/standalone/app.py +4 -3
- nucliadb/standalone/auth.py +5 -6
- nucliadb/standalone/lifecycle.py +2 -2
- nucliadb/standalone/run.py +5 -4
- nucliadb/standalone/settings.py +5 -6
- nucliadb/standalone/versions.py +3 -4
- nucliadb/tasks/consumer.py +13 -8
- nucliadb/tasks/models.py +2 -1
- nucliadb/tasks/producer.py +3 -3
- nucliadb/tasks/retries.py +8 -7
- nucliadb/train/api/utils.py +1 -3
- nucliadb/train/api/v1/shards.py +1 -2
- nucliadb/train/api/v1/trainset.py +1 -2
- nucliadb/train/app.py +1 -1
- nucliadb/train/generator.py +4 -4
- nucliadb/train/generators/field_classifier.py +2 -2
- nucliadb/train/generators/field_streaming.py +6 -6
- nucliadb/train/generators/image_classifier.py +2 -2
- nucliadb/train/generators/paragraph_classifier.py +2 -2
- nucliadb/train/generators/paragraph_streaming.py +2 -2
- nucliadb/train/generators/question_answer_streaming.py +2 -2
- nucliadb/train/generators/sentence_classifier.py +4 -10
- nucliadb/train/generators/token_classifier.py +3 -2
- nucliadb/train/generators/utils.py +6 -5
- nucliadb/train/nodes.py +3 -3
- nucliadb/train/resource.py +6 -8
- nucliadb/train/settings.py +3 -4
- nucliadb/train/types.py +11 -11
- nucliadb/train/upload.py +3 -2
- nucliadb/train/uploader.py +1 -2
- nucliadb/train/utils.py +1 -2
- nucliadb/writer/api/v1/export_import.py +4 -1
- nucliadb/writer/api/v1/field.py +15 -14
- nucliadb/writer/api/v1/knowledgebox.py +18 -56
- nucliadb/writer/api/v1/learning_config.py +5 -4
- nucliadb/writer/api/v1/resource.py +9 -20
- nucliadb/writer/api/v1/services.py +10 -132
- nucliadb/writer/api/v1/upload.py +73 -72
- nucliadb/writer/app.py +8 -2
- nucliadb/writer/resource/basic.py +12 -15
- nucliadb/writer/resource/field.py +43 -5
- nucliadb/writer/resource/origin.py +7 -0
- nucliadb/writer/settings.py +2 -3
- nucliadb/writer/tus/__init__.py +2 -3
- nucliadb/writer/tus/azure.py +5 -7
- nucliadb/writer/tus/dm.py +3 -3
- nucliadb/writer/tus/exceptions.py +3 -4
- nucliadb/writer/tus/gcs.py +15 -22
- nucliadb/writer/tus/s3.py +2 -3
- nucliadb/writer/tus/storage.py +3 -3
- {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/METADATA +10 -11
- nucliadb-6.10.0.post5705.dist-info/RECORD +410 -0
- nucliadb/common/datamanagers/entities.py +0 -139
- nucliadb/common/external_index_providers/pinecone.py +0 -894
- nucliadb/ingest/orm/processor/pgcatalog.py +0 -129
- nucliadb/search/search/hydrator.py +0 -197
- nucliadb-6.7.2.post4874.dist-info/RECORD +0 -383
- {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/WHEEL +0 -0
- {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/top_level.txt +0 -0
nucliadb/search/predict.py
CHANGED
|
@@ -22,9 +22,9 @@ import json
|
|
|
22
22
|
import logging
|
|
23
23
|
import os
|
|
24
24
|
import random
|
|
25
|
+
from collections.abc import AsyncGenerator
|
|
25
26
|
from dataclasses import dataclass
|
|
26
27
|
from enum import Enum
|
|
27
|
-
from typing import AsyncGenerator, Optional
|
|
28
28
|
from unittest.mock import AsyncMock, Mock
|
|
29
29
|
|
|
30
30
|
import aiohttp
|
|
@@ -144,7 +144,7 @@ class AnswerStatusCode(str, Enum):
|
|
|
144
144
|
@dataclass
|
|
145
145
|
class RephraseResponse:
|
|
146
146
|
rephrased_query: str
|
|
147
|
-
use_chat_history:
|
|
147
|
+
use_chat_history: bool | None
|
|
148
148
|
|
|
149
149
|
|
|
150
150
|
async def start_predict_engine():
|
|
@@ -176,18 +176,18 @@ def convert_relations(data: dict[str, list[dict[str, str]]]) -> list[RelationNod
|
|
|
176
176
|
class PredictEngine:
|
|
177
177
|
def __init__(
|
|
178
178
|
self,
|
|
179
|
-
cluster_url:
|
|
180
|
-
public_url:
|
|
181
|
-
nuclia_service_account:
|
|
182
|
-
zone:
|
|
179
|
+
cluster_url: str | None = None,
|
|
180
|
+
public_url: str | None = None,
|
|
181
|
+
nuclia_service_account: str | None = None,
|
|
182
|
+
zone: str | None = None,
|
|
183
183
|
onprem: bool = False,
|
|
184
184
|
local_predict: bool = False,
|
|
185
|
-
local_predict_headers:
|
|
185
|
+
local_predict_headers: dict[str, str] | None = None,
|
|
186
186
|
):
|
|
187
187
|
self.nuclia_service_account = nuclia_service_account
|
|
188
188
|
self.cluster_url = cluster_url
|
|
189
189
|
if public_url is not None:
|
|
190
|
-
self.public_url:
|
|
190
|
+
self.public_url: str | None = public_url.format(zone=zone)
|
|
191
191
|
else:
|
|
192
192
|
self.public_url = None
|
|
193
193
|
self.zone = zone
|
|
@@ -294,7 +294,7 @@ class PredictEngine:
|
|
|
294
294
|
|
|
295
295
|
@predict_observer.wrap({"type": "chat_ndjson"})
|
|
296
296
|
async def chat_query_ndjson(
|
|
297
|
-
self, kbid: str, item: ChatModel, extra_headers:
|
|
297
|
+
self, kbid: str, item: ChatModel, extra_headers: dict[str, str] | None = None
|
|
298
298
|
) -> tuple[str, str, AsyncGenerator[GenerativeChunk, None]]:
|
|
299
299
|
"""
|
|
300
300
|
Chat query using the new stream format
|
|
@@ -383,7 +383,7 @@ class PredictEngine:
|
|
|
383
383
|
|
|
384
384
|
@predict_observer.wrap({"type": "summarize"})
|
|
385
385
|
async def summarize(
|
|
386
|
-
self, kbid: str, item: SummarizeModel, extra_headers:
|
|
386
|
+
self, kbid: str, item: SummarizeModel, extra_headers: dict[str, str] | None = None
|
|
387
387
|
) -> SummarizedResponse:
|
|
388
388
|
try:
|
|
389
389
|
self.check_nua_key_is_configured_for_onprem()
|
|
@@ -447,6 +447,10 @@ class DummyPredictEngine(PredictEngine):
|
|
|
447
447
|
self.cluster_url = "http://localhost:8000"
|
|
448
448
|
self.public_url = "http://localhost:8000"
|
|
449
449
|
self.calls = []
|
|
450
|
+
self.ndjson_reasoning = [
|
|
451
|
+
b'{"chunk": {"type": "reasoning", "text": "dummy "}}\n',
|
|
452
|
+
b'{"chunk": {"type": "reasoning", "text": "reasoning"}}\n',
|
|
453
|
+
]
|
|
450
454
|
self.ndjson_answer = [
|
|
451
455
|
b'{"chunk": {"type": "text", "text": "valid "}}\n',
|
|
452
456
|
b'{"chunk": {"type": "text", "text": "answer "}}\n',
|
|
@@ -477,13 +481,16 @@ class DummyPredictEngine(PredictEngine):
|
|
|
477
481
|
return RephraseResponse(rephrased_query=DUMMY_REPHRASE_QUERY, use_chat_history=None)
|
|
478
482
|
|
|
479
483
|
async def chat_query_ndjson(
|
|
480
|
-
self, kbid: str, item: ChatModel, extra_headers:
|
|
484
|
+
self, kbid: str, item: ChatModel, extra_headers: dict[str, str] | None = None
|
|
481
485
|
) -> tuple[str, str, AsyncGenerator[GenerativeChunk, None]]:
|
|
482
486
|
self.calls.append(("chat_query_ndjson", item))
|
|
483
487
|
|
|
484
488
|
async def generate():
|
|
485
|
-
|
|
486
|
-
|
|
489
|
+
if item.reasoning is not False:
|
|
490
|
+
for chunk in self.ndjson_reasoning:
|
|
491
|
+
yield GenerativeChunk.model_validate_json(chunk)
|
|
492
|
+
for chunk in self.ndjson_answer:
|
|
493
|
+
yield GenerativeChunk.model_validate_json(chunk)
|
|
487
494
|
|
|
488
495
|
return (DUMMY_LEARNING_ID, DUMMY_LEARNING_MODEL, generate())
|
|
489
496
|
|
|
@@ -517,10 +524,17 @@ class DummyPredictEngine(PredictEngine):
|
|
|
517
524
|
timings[vectorset_id] = 0.010
|
|
518
525
|
|
|
519
526
|
# and fake data with the passed one too
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
527
|
+
if item.semantic_models is not None:
|
|
528
|
+
for model in item.semantic_models:
|
|
529
|
+
semantic_thresholds[model] = self.default_semantic_threshold
|
|
530
|
+
vectors[model] = base_vector
|
|
531
|
+
timings[model] = 0.0
|
|
532
|
+
|
|
533
|
+
if len(vectors) == 0:
|
|
534
|
+
model = "<PREDICT-DEFAULT-SEMANTIC-MODEL>"
|
|
535
|
+
semantic_thresholds[model] = self.default_semantic_threshold
|
|
536
|
+
vectors[model] = base_vector
|
|
537
|
+
timings[model] = 0.0
|
|
524
538
|
|
|
525
539
|
return QueryInfo(
|
|
526
540
|
language="en",
|
|
@@ -533,7 +547,7 @@ class DummyPredictEngine(PredictEngine):
|
|
|
533
547
|
vectors=vectors,
|
|
534
548
|
timings=timings,
|
|
535
549
|
),
|
|
536
|
-
query=
|
|
550
|
+
query=item.text or "<PREDICT-QUERY>",
|
|
537
551
|
rephrased_query="<REPHRASED-QUERY>" if item.rephrase or item.query_image else None,
|
|
538
552
|
)
|
|
539
553
|
|
|
@@ -546,7 +560,7 @@ class DummyPredictEngine(PredictEngine):
|
|
|
546
560
|
return DUMMY_RELATION_NODE
|
|
547
561
|
|
|
548
562
|
async def summarize(
|
|
549
|
-
self, kbid: str, item: SummarizeModel, extra_headers:
|
|
563
|
+
self, kbid: str, item: SummarizeModel, extra_headers: dict[str, str] | None = None
|
|
550
564
|
) -> SummarizedResponse:
|
|
551
565
|
self.calls.append(("summarize", (kbid, item)))
|
|
552
566
|
response = SummarizedResponse(
|
|
@@ -19,7 +19,6 @@
|
|
|
19
19
|
|
|
20
20
|
from base64 import b64decode, b64encode
|
|
21
21
|
from enum import Enum
|
|
22
|
-
from typing import Optional
|
|
23
22
|
|
|
24
23
|
from google.protobuf.message import DecodeError, Message
|
|
25
24
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|
@@ -77,7 +76,7 @@ class RunAgentsRequest(BaseModel):
|
|
|
77
76
|
default_factory=list,
|
|
78
77
|
title="An optional list of Data Augmentation Agent IDs to run. If empty, all configured agents that match the filters are run.",
|
|
79
78
|
)
|
|
80
|
-
filters:
|
|
79
|
+
filters: list[NameOperationFilter] | None = Field(
|
|
81
80
|
default=None,
|
|
82
81
|
title="Filters to select which Data Augmentation Agents are applied to the text. If empty, all configured agents for the Knowledge Box are applied.",
|
|
83
82
|
)
|
|
@@ -93,7 +92,7 @@ class AppliedDataAugmentation(BaseModel):
|
|
|
93
92
|
# Since we have protos as fields, we need to enable arbitrary_types_allowed
|
|
94
93
|
arbitrary_types_allowed=True,
|
|
95
94
|
)
|
|
96
|
-
qas:
|
|
95
|
+
qas: QuestionAnswers | None = Field(
|
|
97
96
|
default=None,
|
|
98
97
|
description="Question and answers generated by the Question Answers agent",
|
|
99
98
|
)
|
|
@@ -107,7 +106,7 @@ class AppliedDataAugmentation(BaseModel):
|
|
|
107
106
|
)
|
|
108
107
|
|
|
109
108
|
@field_validator("qas", mode="before")
|
|
110
|
-
def validate_qas(cls, qas:
|
|
109
|
+
def validate_qas(cls, qas: str | None) -> QuestionAnswers | None:
|
|
111
110
|
if qas is None:
|
|
112
111
|
return None
|
|
113
112
|
try:
|
|
@@ -171,8 +170,8 @@ class QueryModel(BaseModel):
|
|
|
171
170
|
Model to represent a query request
|
|
172
171
|
"""
|
|
173
172
|
|
|
174
|
-
text:
|
|
175
|
-
query_image:
|
|
173
|
+
text: str | None = Field(default=None, description="The query text to be processed")
|
|
174
|
+
query_image: Image | None = Field(
|
|
176
175
|
default=None,
|
|
177
176
|
description="Image to be considered as part of the query. Even if the `rephrase` parameter is set to `false`, the rephrasing process will occur, combining the provided text with the image's visual features in the rephrased query.",
|
|
178
177
|
)
|
|
@@ -180,7 +179,7 @@ class QueryModel(BaseModel):
|
|
|
180
179
|
default=False,
|
|
181
180
|
description="If true, the model will rephrase the input text before processing",
|
|
182
181
|
)
|
|
183
|
-
rephrase_prompt:
|
|
182
|
+
rephrase_prompt: str | None = Field(
|
|
184
183
|
default=None,
|
|
185
184
|
description="Custom prompt for rephrasing the input text",
|
|
186
185
|
examples=[
|
|
@@ -192,11 +191,11 @@ QUESTION: {question}
|
|
|
192
191
|
Please return ONLY the question without any explanation.""",
|
|
193
192
|
],
|
|
194
193
|
)
|
|
195
|
-
generative_model:
|
|
194
|
+
generative_model: str | None = Field(
|
|
196
195
|
default=None,
|
|
197
196
|
description="The generative model to use for rephrasing",
|
|
198
197
|
)
|
|
199
|
-
semantic_models:
|
|
198
|
+
semantic_models: list[str] | None = Field(
|
|
200
199
|
default=None,
|
|
201
200
|
description="Semantic models to compute the sentence vector for, if not provided, it will only compute the sentence vector for default semantic model in the Knowledge box's configuration.",
|
|
202
201
|
)
|
|
@@ -19,8 +19,9 @@
|
|
|
19
19
|
|
|
20
20
|
import asyncio
|
|
21
21
|
import json
|
|
22
|
+
from collections.abc import Sequence
|
|
22
23
|
from enum import Enum, auto
|
|
23
|
-
from typing import
|
|
24
|
+
from typing import TypeVar, overload
|
|
24
25
|
|
|
25
26
|
from fastapi import HTTPException
|
|
26
27
|
from google.protobuf.json_format import MessageToDict
|
|
@@ -60,7 +61,7 @@ METHODS = {
|
|
|
60
61
|
Method.GRAPH: graph_search_shard,
|
|
61
62
|
}
|
|
62
63
|
|
|
63
|
-
REQUEST_TYPE =
|
|
64
|
+
REQUEST_TYPE = SuggestRequest | SearchRequest | GraphSearchRequest
|
|
64
65
|
|
|
65
66
|
T = TypeVar(
|
|
66
67
|
"T",
|
|
@@ -75,7 +76,7 @@ async def nidx_query(
|
|
|
75
76
|
kbid: str,
|
|
76
77
|
method: Method,
|
|
77
78
|
pb_query: SuggestRequest,
|
|
78
|
-
timeout:
|
|
79
|
+
timeout: float | None = None,
|
|
79
80
|
) -> tuple[list[SuggestResponse], list[str]]: ...
|
|
80
81
|
|
|
81
82
|
|
|
@@ -84,7 +85,7 @@ async def nidx_query(
|
|
|
84
85
|
kbid: str,
|
|
85
86
|
method: Method,
|
|
86
87
|
pb_query: SearchRequest,
|
|
87
|
-
timeout:
|
|
88
|
+
timeout: float | None = None,
|
|
88
89
|
) -> tuple[list[SearchResponse], list[str]]: ...
|
|
89
90
|
|
|
90
91
|
|
|
@@ -93,7 +94,7 @@ async def nidx_query(
|
|
|
93
94
|
kbid: str,
|
|
94
95
|
method: Method,
|
|
95
96
|
pb_query: GraphSearchRequest,
|
|
96
|
-
timeout:
|
|
97
|
+
timeout: float | None = None,
|
|
97
98
|
) -> tuple[list[GraphSearchResponse], list[str]]: ...
|
|
98
99
|
|
|
99
100
|
|
|
@@ -101,8 +102,8 @@ async def nidx_query(
|
|
|
101
102
|
kbid: str,
|
|
102
103
|
method: Method,
|
|
103
104
|
pb_query: REQUEST_TYPE,
|
|
104
|
-
timeout:
|
|
105
|
-
) -> tuple[Sequence[
|
|
105
|
+
timeout: float | None = None,
|
|
106
|
+
) -> tuple[Sequence[T | BaseException], list[str]]:
|
|
106
107
|
timeout = timeout or settings.search_timeout
|
|
107
108
|
shard_manager = get_shard_manager()
|
|
108
109
|
try:
|
|
@@ -133,7 +134,7 @@ async def nidx_query(
|
|
|
133
134
|
)
|
|
134
135
|
|
|
135
136
|
try:
|
|
136
|
-
results: list[
|
|
137
|
+
results: list[T | BaseException] = await asyncio.wait_for(
|
|
137
138
|
asyncio.gather(*ops, return_exceptions=True),
|
|
138
139
|
timeout=timeout,
|
|
139
140
|
)
|
|
@@ -159,13 +160,13 @@ async def nidx_query(
|
|
|
159
160
|
return results, queried_shards
|
|
160
161
|
|
|
161
162
|
|
|
162
|
-
def validate_nidx_query_results(results: list[
|
|
163
|
+
def validate_nidx_query_results(results: list[T | BaseException]) -> HTTPException | None:
|
|
163
164
|
"""
|
|
164
165
|
Validate the results of a nidx query and return an exception if any error is found
|
|
165
166
|
|
|
166
167
|
Handling of exception is responsibility of caller.
|
|
167
168
|
"""
|
|
168
|
-
if
|
|
169
|
+
if len(results) == 0:
|
|
169
170
|
return HTTPException(status_code=500, detail=f"Error while executing shard queries. No results.")
|
|
170
171
|
|
|
171
172
|
for result in results:
|
nucliadb/search/search/cache.py
CHANGED
|
@@ -19,9 +19,6 @@
|
|
|
19
19
|
|
|
20
20
|
import contextlib
|
|
21
21
|
import logging
|
|
22
|
-
from typing import Optional
|
|
23
|
-
|
|
24
|
-
import backoff
|
|
25
22
|
|
|
26
23
|
from nucliadb.common.cache import (
|
|
27
24
|
extracted_text_cache,
|
|
@@ -41,26 +38,35 @@ from nucliadb_utils.utilities import get_storage
|
|
|
41
38
|
logger = logging.getLogger(__name__)
|
|
42
39
|
|
|
43
40
|
|
|
44
|
-
async def get_resource(kbid: str, uuid: str) ->
|
|
41
|
+
async def get_resource(kbid: str, uuid: str) -> ResourceORM | None:
|
|
45
42
|
"""
|
|
46
43
|
Will try to get the resource from the cache, if it's not there it will fetch it from the ORM and cache it.
|
|
47
44
|
"""
|
|
48
45
|
resource_cache = get_resource_cache()
|
|
49
46
|
if resource_cache is None:
|
|
50
47
|
logger.warning("Resource cache not set")
|
|
51
|
-
|
|
48
|
+
async with get_driver().ro_transaction() as txn:
|
|
49
|
+
storage = await get_storage(service_name=SERVICE_NAME)
|
|
50
|
+
kb = KnowledgeBoxORM(txn, storage, kbid)
|
|
51
|
+
return await kb.get(uuid)
|
|
52
52
|
|
|
53
53
|
return await resource_cache.get(kbid, uuid)
|
|
54
54
|
|
|
55
55
|
|
|
56
|
-
async def
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
return
|
|
56
|
+
async def get_field(kbid: str, field_id: FieldId) -> Field | None:
|
|
57
|
+
rid = field_id.rid
|
|
58
|
+
orm_resource = await get_resource(kbid, rid)
|
|
59
|
+
if orm_resource is None:
|
|
60
|
+
return None
|
|
61
|
+
field_obj = await orm_resource.get_field(
|
|
62
|
+
key=field_id.key,
|
|
63
|
+
type=field_id.pb_type,
|
|
64
|
+
load=False,
|
|
65
|
+
)
|
|
66
|
+
return field_obj
|
|
61
67
|
|
|
62
68
|
|
|
63
|
-
async def get_field_extracted_text(field: Field) ->
|
|
69
|
+
async def get_field_extracted_text(field: Field) -> ExtractedText | None:
|
|
64
70
|
if field.extracted_text is not None:
|
|
65
71
|
return field.extracted_text
|
|
66
72
|
|
|
@@ -74,36 +80,6 @@ async def get_field_extracted_text(field: Field) -> Optional[ExtractedText]:
|
|
|
74
80
|
return extracted_text
|
|
75
81
|
|
|
76
82
|
|
|
77
|
-
@backoff.on_exception(backoff.expo, (Exception,), jitter=backoff.random_jitter, max_tries=3)
|
|
78
|
-
async def field_get_extracted_text(field: Field) -> Optional[ExtractedText]:
|
|
79
|
-
try:
|
|
80
|
-
return await field.get_extracted_text()
|
|
81
|
-
except Exception:
|
|
82
|
-
logger.warning(
|
|
83
|
-
"Error getting extracted text for field. Retrying",
|
|
84
|
-
exc_info=True,
|
|
85
|
-
extra={
|
|
86
|
-
"kbid": field.kbid,
|
|
87
|
-
"resource_id": field.resource.uuid,
|
|
88
|
-
"field": f"{field.type}/{field.id}",
|
|
89
|
-
},
|
|
90
|
-
)
|
|
91
|
-
raise
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
async def get_extracted_text_from_field_id(kbid: str, field: FieldId) -> Optional[ExtractedText]:
|
|
95
|
-
rid = field.rid
|
|
96
|
-
orm_resource = await get_resource(kbid, rid)
|
|
97
|
-
if orm_resource is None:
|
|
98
|
-
return None
|
|
99
|
-
field_obj = await orm_resource.get_field(
|
|
100
|
-
key=field.key,
|
|
101
|
-
type=field.pb_type,
|
|
102
|
-
load=False,
|
|
103
|
-
)
|
|
104
|
-
return await get_field_extracted_text(field_obj)
|
|
105
|
-
|
|
106
|
-
|
|
107
83
|
@contextlib.contextmanager
|
|
108
84
|
def request_caches():
|
|
109
85
|
"""
|
|
@@ -115,7 +91,8 @@ def request_caches():
|
|
|
115
91
|
Makes sure to clean the caches at the end of the context manager.
|
|
116
92
|
>>> with request_caches():
|
|
117
93
|
... resource = await get_resource(kbid, uuid)
|
|
118
|
-
...
|
|
94
|
+
... field = await get_field(kbid, field_id)
|
|
95
|
+
... extracted_text = await get_field_extracted_text(field)
|
|
119
96
|
"""
|
|
120
97
|
|
|
121
98
|
# This cache size is an arbitrary number, once we have a metric in place and
|