nucliadb 6.7.2.post4874__py3-none-any.whl → 6.10.0.post5705__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- migrations/0023_backfill_pg_catalog.py +8 -4
- migrations/0028_extracted_vectors_reference.py +1 -1
- migrations/0029_backfill_field_status.py +3 -4
- migrations/0032_remove_old_relations.py +2 -3
- migrations/0038_backfill_catalog_field_labels.py +8 -4
- migrations/0039_backfill_converation_splits_metadata.py +106 -0
- migrations/0040_migrate_search_configurations.py +79 -0
- migrations/0041_reindex_conversations.py +137 -0
- migrations/pg/0010_shards_index.py +34 -0
- nucliadb/search/api/v1/resource/utils.py → migrations/pg/0011_catalog_statistics.py +5 -6
- migrations/pg/0012_catalog_statistics_undo.py +26 -0
- nucliadb/backups/create.py +2 -15
- nucliadb/backups/restore.py +4 -15
- nucliadb/backups/tasks.py +4 -1
- nucliadb/common/back_pressure/cache.py +2 -3
- nucliadb/common/back_pressure/materializer.py +7 -13
- nucliadb/common/back_pressure/settings.py +6 -6
- nucliadb/common/back_pressure/utils.py +1 -0
- nucliadb/common/cache.py +9 -9
- nucliadb/common/catalog/__init__.py +79 -0
- nucliadb/common/catalog/dummy.py +36 -0
- nucliadb/common/catalog/interface.py +85 -0
- nucliadb/{search/search/pgcatalog.py → common/catalog/pg.py} +330 -232
- nucliadb/common/catalog/utils.py +56 -0
- nucliadb/common/cluster/manager.py +8 -23
- nucliadb/common/cluster/rebalance.py +484 -112
- nucliadb/common/cluster/rollover.py +36 -9
- nucliadb/common/cluster/settings.py +4 -9
- nucliadb/common/cluster/utils.py +34 -8
- nucliadb/common/context/__init__.py +7 -8
- nucliadb/common/context/fastapi.py +1 -2
- nucliadb/common/datamanagers/__init__.py +2 -4
- nucliadb/common/datamanagers/atomic.py +9 -2
- nucliadb/common/datamanagers/cluster.py +1 -2
- nucliadb/common/datamanagers/fields.py +3 -4
- nucliadb/common/datamanagers/kb.py +6 -6
- nucliadb/common/datamanagers/labels.py +2 -3
- nucliadb/common/datamanagers/resources.py +10 -33
- nucliadb/common/datamanagers/rollover.py +5 -7
- nucliadb/common/datamanagers/search_configurations.py +1 -2
- nucliadb/common/datamanagers/synonyms.py +1 -2
- nucliadb/common/datamanagers/utils.py +4 -4
- nucliadb/common/datamanagers/vectorsets.py +4 -4
- nucliadb/common/external_index_providers/base.py +32 -5
- nucliadb/common/external_index_providers/manager.py +5 -34
- nucliadb/common/external_index_providers/settings.py +1 -27
- nucliadb/common/filter_expression.py +129 -41
- nucliadb/common/http_clients/exceptions.py +8 -0
- nucliadb/common/http_clients/processing.py +16 -23
- nucliadb/common/http_clients/utils.py +3 -0
- nucliadb/common/ids.py +82 -58
- nucliadb/common/locking.py +1 -2
- nucliadb/common/maindb/driver.py +9 -8
- nucliadb/common/maindb/local.py +5 -5
- nucliadb/common/maindb/pg.py +9 -8
- nucliadb/common/nidx.py +22 -5
- nucliadb/common/vector_index_config.py +1 -1
- nucliadb/export_import/datamanager.py +4 -3
- nucliadb/export_import/exporter.py +11 -19
- nucliadb/export_import/importer.py +13 -6
- nucliadb/export_import/tasks.py +2 -0
- nucliadb/export_import/utils.py +6 -18
- nucliadb/health.py +2 -2
- nucliadb/ingest/app.py +8 -8
- nucliadb/ingest/consumer/consumer.py +8 -10
- nucliadb/ingest/consumer/pull.py +10 -8
- nucliadb/ingest/consumer/service.py +5 -30
- nucliadb/ingest/consumer/shard_creator.py +16 -5
- nucliadb/ingest/consumer/utils.py +1 -1
- nucliadb/ingest/fields/base.py +37 -49
- nucliadb/ingest/fields/conversation.py +55 -9
- nucliadb/ingest/fields/exceptions.py +1 -2
- nucliadb/ingest/fields/file.py +22 -8
- nucliadb/ingest/fields/link.py +7 -7
- nucliadb/ingest/fields/text.py +2 -3
- nucliadb/ingest/orm/brain_v2.py +89 -57
- nucliadb/ingest/orm/broker_message.py +2 -4
- nucliadb/ingest/orm/entities.py +10 -209
- nucliadb/ingest/orm/index_message.py +128 -113
- nucliadb/ingest/orm/knowledgebox.py +91 -59
- nucliadb/ingest/orm/processor/auditing.py +1 -3
- nucliadb/ingest/orm/processor/data_augmentation.py +1 -2
- nucliadb/ingest/orm/processor/processor.py +98 -153
- nucliadb/ingest/orm/processor/sequence_manager.py +1 -2
- nucliadb/ingest/orm/resource.py +82 -71
- nucliadb/ingest/orm/utils.py +1 -1
- nucliadb/ingest/partitions.py +12 -1
- nucliadb/ingest/processing.py +17 -17
- nucliadb/ingest/serialize.py +202 -145
- nucliadb/ingest/service/writer.py +15 -114
- nucliadb/ingest/settings.py +36 -15
- nucliadb/ingest/utils.py +1 -2
- nucliadb/learning_proxy.py +23 -26
- nucliadb/metrics_exporter.py +20 -6
- nucliadb/middleware/__init__.py +82 -1
- nucliadb/migrator/datamanager.py +4 -11
- nucliadb/migrator/migrator.py +1 -2
- nucliadb/migrator/models.py +1 -2
- nucliadb/migrator/settings.py +1 -2
- nucliadb/models/internal/augment.py +614 -0
- nucliadb/models/internal/processing.py +19 -19
- nucliadb/openapi.py +2 -2
- nucliadb/purge/__init__.py +3 -8
- nucliadb/purge/orphan_shards.py +1 -2
- nucliadb/reader/__init__.py +5 -0
- nucliadb/reader/api/models.py +6 -13
- nucliadb/reader/api/v1/download.py +59 -38
- nucliadb/reader/api/v1/export_import.py +4 -4
- nucliadb/reader/api/v1/knowledgebox.py +37 -9
- nucliadb/reader/api/v1/learning_config.py +33 -14
- nucliadb/reader/api/v1/resource.py +61 -9
- nucliadb/reader/api/v1/services.py +18 -14
- nucliadb/reader/app.py +3 -1
- nucliadb/reader/reader/notifications.py +1 -2
- nucliadb/search/api/v1/__init__.py +3 -0
- nucliadb/search/api/v1/ask.py +3 -4
- nucliadb/search/api/v1/augment.py +585 -0
- nucliadb/search/api/v1/catalog.py +15 -19
- nucliadb/search/api/v1/find.py +16 -22
- nucliadb/search/api/v1/hydrate.py +328 -0
- nucliadb/search/api/v1/knowledgebox.py +1 -2
- nucliadb/search/api/v1/predict_proxy.py +1 -2
- nucliadb/search/api/v1/resource/ask.py +28 -8
- nucliadb/search/api/v1/resource/ingestion_agents.py +5 -6
- nucliadb/search/api/v1/resource/search.py +9 -11
- nucliadb/search/api/v1/retrieve.py +130 -0
- nucliadb/search/api/v1/search.py +28 -32
- nucliadb/search/api/v1/suggest.py +11 -14
- nucliadb/search/api/v1/summarize.py +1 -2
- nucliadb/search/api/v1/utils.py +2 -2
- nucliadb/search/app.py +3 -2
- nucliadb/search/augmentor/__init__.py +21 -0
- nucliadb/search/augmentor/augmentor.py +232 -0
- nucliadb/search/augmentor/fields.py +704 -0
- nucliadb/search/augmentor/metrics.py +24 -0
- nucliadb/search/augmentor/paragraphs.py +334 -0
- nucliadb/search/augmentor/resources.py +238 -0
- nucliadb/search/augmentor/utils.py +33 -0
- nucliadb/search/lifecycle.py +3 -1
- nucliadb/search/predict.py +33 -19
- nucliadb/search/predict_models.py +8 -9
- nucliadb/search/requesters/utils.py +11 -10
- nucliadb/search/search/cache.py +19 -42
- nucliadb/search/search/chat/ask.py +131 -59
- nucliadb/search/search/chat/exceptions.py +3 -5
- nucliadb/search/search/chat/fetcher.py +201 -0
- nucliadb/search/search/chat/images.py +6 -4
- nucliadb/search/search/chat/old_prompt.py +1375 -0
- nucliadb/search/search/chat/parser.py +510 -0
- nucliadb/search/search/chat/prompt.py +563 -615
- nucliadb/search/search/chat/query.py +453 -32
- nucliadb/search/search/chat/rpc.py +85 -0
- nucliadb/search/search/fetch.py +3 -4
- nucliadb/search/search/filters.py +8 -11
- nucliadb/search/search/find.py +33 -31
- nucliadb/search/search/find_merge.py +124 -331
- nucliadb/search/search/graph_strategy.py +14 -12
- nucliadb/search/search/hydrator/__init__.py +49 -0
- nucliadb/search/search/hydrator/fields.py +217 -0
- nucliadb/search/search/hydrator/images.py +130 -0
- nucliadb/search/search/hydrator/paragraphs.py +323 -0
- nucliadb/search/search/hydrator/resources.py +60 -0
- nucliadb/search/search/ingestion_agents.py +5 -5
- nucliadb/search/search/merge.py +90 -94
- nucliadb/search/search/metrics.py +24 -7
- nucliadb/search/search/paragraphs.py +7 -9
- nucliadb/search/search/predict_proxy.py +44 -18
- nucliadb/search/search/query.py +14 -86
- nucliadb/search/search/query_parser/fetcher.py +51 -82
- nucliadb/search/search/query_parser/models.py +19 -48
- nucliadb/search/search/query_parser/old_filters.py +20 -19
- nucliadb/search/search/query_parser/parsers/ask.py +5 -6
- nucliadb/search/search/query_parser/parsers/catalog.py +7 -11
- nucliadb/search/search/query_parser/parsers/common.py +21 -13
- nucliadb/search/search/query_parser/parsers/find.py +6 -29
- nucliadb/search/search/query_parser/parsers/graph.py +18 -28
- nucliadb/search/search/query_parser/parsers/retrieve.py +207 -0
- nucliadb/search/search/query_parser/parsers/search.py +15 -56
- nucliadb/search/search/query_parser/parsers/unit_retrieval.py +8 -29
- nucliadb/search/search/rank_fusion.py +18 -13
- nucliadb/search/search/rerankers.py +6 -7
- nucliadb/search/search/retrieval.py +300 -0
- nucliadb/search/search/summarize.py +5 -6
- nucliadb/search/search/utils.py +3 -4
- nucliadb/search/settings.py +1 -2
- nucliadb/standalone/api_router.py +1 -1
- nucliadb/standalone/app.py +4 -3
- nucliadb/standalone/auth.py +5 -6
- nucliadb/standalone/lifecycle.py +2 -2
- nucliadb/standalone/run.py +5 -4
- nucliadb/standalone/settings.py +5 -6
- nucliadb/standalone/versions.py +3 -4
- nucliadb/tasks/consumer.py +13 -8
- nucliadb/tasks/models.py +2 -1
- nucliadb/tasks/producer.py +3 -3
- nucliadb/tasks/retries.py +8 -7
- nucliadb/train/api/utils.py +1 -3
- nucliadb/train/api/v1/shards.py +1 -2
- nucliadb/train/api/v1/trainset.py +1 -2
- nucliadb/train/app.py +1 -1
- nucliadb/train/generator.py +4 -4
- nucliadb/train/generators/field_classifier.py +2 -2
- nucliadb/train/generators/field_streaming.py +6 -6
- nucliadb/train/generators/image_classifier.py +2 -2
- nucliadb/train/generators/paragraph_classifier.py +2 -2
- nucliadb/train/generators/paragraph_streaming.py +2 -2
- nucliadb/train/generators/question_answer_streaming.py +2 -2
- nucliadb/train/generators/sentence_classifier.py +4 -10
- nucliadb/train/generators/token_classifier.py +3 -2
- nucliadb/train/generators/utils.py +6 -5
- nucliadb/train/nodes.py +3 -3
- nucliadb/train/resource.py +6 -8
- nucliadb/train/settings.py +3 -4
- nucliadb/train/types.py +11 -11
- nucliadb/train/upload.py +3 -2
- nucliadb/train/uploader.py +1 -2
- nucliadb/train/utils.py +1 -2
- nucliadb/writer/api/v1/export_import.py +4 -1
- nucliadb/writer/api/v1/field.py +15 -14
- nucliadb/writer/api/v1/knowledgebox.py +18 -56
- nucliadb/writer/api/v1/learning_config.py +5 -4
- nucliadb/writer/api/v1/resource.py +9 -20
- nucliadb/writer/api/v1/services.py +10 -132
- nucliadb/writer/api/v1/upload.py +73 -72
- nucliadb/writer/app.py +8 -2
- nucliadb/writer/resource/basic.py +12 -15
- nucliadb/writer/resource/field.py +43 -5
- nucliadb/writer/resource/origin.py +7 -0
- nucliadb/writer/settings.py +2 -3
- nucliadb/writer/tus/__init__.py +2 -3
- nucliadb/writer/tus/azure.py +5 -7
- nucliadb/writer/tus/dm.py +3 -3
- nucliadb/writer/tus/exceptions.py +3 -4
- nucliadb/writer/tus/gcs.py +15 -22
- nucliadb/writer/tus/s3.py +2 -3
- nucliadb/writer/tus/storage.py +3 -3
- {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/METADATA +10 -11
- nucliadb-6.10.0.post5705.dist-info/RECORD +410 -0
- nucliadb/common/datamanagers/entities.py +0 -139
- nucliadb/common/external_index_providers/pinecone.py +0 -894
- nucliadb/ingest/orm/processor/pgcatalog.py +0 -129
- nucliadb/search/search/hydrator.py +0 -197
- nucliadb-6.7.2.post4874.dist-info/RECORD +0 -383
- {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/WHEEL +0 -0
- {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.7.2.post4874.dist-info → nucliadb-6.10.0.post5705.dist-info}/top_level.txt +0 -0
nucliadb/search/search/summarize.py CHANGED
@@ -18,7 +18,6 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 import asyncio
-from typing import Optional

 from nucliadb.common import datamanagers
 from nucliadb.common.maindb.utils import get_driver
@@ -36,7 +35,7 @@ from nucliadb_models.search import (
 from nucliadb_protos.utils_pb2 import ExtractedText
 from nucliadb_utils.utilities import get_storage

-ExtractedTexts = list[tuple[str, str, Optional[ExtractedText]]]
+ExtractedTexts = list[tuple[str, str, ExtractedText | None]]

 MAX_GET_EXTRACTED_TEXT_OPS = 20

@@ -46,7 +45,7 @@ class NoResourcesToSummarize(Exception):


 async def summarize(
-    kbid: str, request: SummarizeRequest, extra_predict_headers: Optional[dict[str, str]]
+    kbid: str, request: SummarizeRequest, extra_predict_headers: dict[str, str] | None
 ) -> SummarizedResponse:
     predict_request = SummarizeModel()
     predict_request.generative_model = request.generative_model
@@ -87,7 +86,7 @@ async def get_extracted_texts(kbid: str, resource_uuids_or_slugs: list[str]) ->
             if uuid is None:
                 logger.warning(f"Resource {uuid_or_slug} not found in KB", extra={"kbid": kbid})
                 continue
-            resource_orm = Resource(txn=txn, storage=storage,
+            resource_orm = Resource(txn=txn, storage=storage, kbid=kbid, uuid=uuid)
             fields = await resource_orm.get_fields(force=True)
             for _, field in fields.items():
                 task = asyncio.create_task(get_extracted_text(uuid_or_slug, field, max_tasks))
@@ -115,14 +114,14 @@ async def get_extracted_texts(kbid: str, resource_uuids_or_slugs: list[str]) ->

 async def get_extracted_text(
     uuid_or_slug, field: Field, max_operations: asyncio.Semaphore
-) -> tuple[str, str, Optional[ExtractedText]]:
+) -> tuple[str, str, ExtractedText | None]:
     async with max_operations:
         extracted_text = await field.get_extracted_text(force=True)
         field_key = f"{field.type}/{field.id}"
         return uuid_or_slug, field_key, extracted_text


-async def get_resource_uuid(kbobj: KnowledgeBox, uuid_or_slug: str) -> Optional[str]:
+async def get_resource_uuid(kbobj: KnowledgeBox, uuid_or_slug: str) -> str | None:
     """
     Return the uuid of the resource with the given uuid_or_slug.
     """
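Nearly all of the churn in these per-file diffs follows one pattern: Optional[X] annotations are rewritten to the PEP 604 union spelling X | None (Python 3.10+), and the now-unused "from typing import Optional" imports are dropped. A minimal before/after sketch of the pattern; the function and its body are placeholders, not code from the package:

# Before: typing.Optional spelling
from typing import Optional

def lookup_uuid_old(uuid_or_slug: str) -> Optional[str]:
    return uuid_or_slug or None

# After: PEP 604 union syntax, no typing import needed
def lookup_uuid_new(uuid_or_slug: str) -> str | None:
    return uuid_or_slug or None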
nucliadb/search/search/utils.py CHANGED
@@ -18,7 +18,6 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 import logging
-from typing import Optional

 from pydantic import BaseModel

@@ -30,7 +29,7 @@ from nucliadb_utils.utilities import has_feature
 logger = logging.getLogger(__name__)


-async def filter_hidden_resources(kbid: str, show_hidden: bool) -> Optional[bool]:
+async def filter_hidden_resources(kbid: str, show_hidden: bool) -> bool | None:
     kb_config = await kb.get_config(kbid=kbid)
     hidden_enabled = kb_config and kb_config.hidden_resources_enabled
     if hidden_enabled and not show_hidden:
@@ -41,8 +40,8 @@ async def filter_hidden_resources(kbid: str, show_hidden: bool) -> Optional[bool

 def min_score_from_query_params(
     min_score_bm25: float,
-    min_score_semantic: Optional[float],
-    deprecated_min_score: Optional[float],
+    min_score_semantic: float | None,
+    deprecated_min_score: float | None,
 ) -> MinScore:
     # Keep backward compatibility with the deprecated min_score parameter
     semantic = deprecated_min_score if min_score_semantic is None else min_score_semantic
nucliadb/search/settings.py CHANGED
@@ -18,7 +18,6 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #

-from typing import Optional

 from pydantic import Field

@@ -43,7 +42,7 @@ class Settings(DriverSettings):
         title="Prequeries max parallel",
         description="The maximum number of prequeries to run in parallel per /ask request",
     )
-    nidx_address: Optional[str] = Field(default=None)
+    nidx_address: str | None = Field(default=None)


 settings = Settings()
nucliadb/standalone/api_router.py CHANGED
@@ -57,7 +57,7 @@ async def api_config_check(request: Request):
         valid_nua_key = True
     except Exception as exc:
         logger.warning(f"Error validating nua key", exc_info=exc)
-        nua_key_check_error = f"Error checking NUA key: {
+        nua_key_check_error = f"Error checking NUA key: {exc!s}"
     return JSONResponse(
         {
             "nua_api_key": {
nucliadb/standalone/app.py CHANGED
@@ -31,7 +31,7 @@ from starlette.responses import HTMLResponse
 from starlette.routing import Mount

 import nucliadb_admin_assets  # type: ignore
-from nucliadb.middleware import ProcessTimeHeaderMiddleware
+from nucliadb.middleware import ClientErrorPayloadLoggerMiddleware, ProcessTimeHeaderMiddleware
 from nucliadb.reader import API_PREFIX
 from nucliadb.reader.api.v1.router import api as api_reader_v1
 from nucliadb.search.api.v1.router import api as api_search_v1
@@ -79,7 +79,7 @@ HOMEPAGE_HTML = """
 </ul>
 </body>
 </html>
-"""
+"""


 def application_factory(settings: Settings) -> FastAPI:
@@ -95,13 +95,13 @@ def application_factory(settings: Settings) -> FastAPI:
             backend=get_auth_backend(settings),
         ),
         Middleware(AuditMiddleware, audit_utility_getter=get_audit),
+        Middleware(ClientErrorPayloadLoggerMiddleware),
     ]
     if running_settings.debug:
         middleware.append(Middleware(ProcessTimeHeaderMiddleware))

     fastapi_settings = dict(
         debug=running_settings.debug,
-        middleware=middleware,
         lifespan=lifespan,
         exception_handlers={
             Exception: global_exception_handler,
@@ -122,6 +122,7 @@ def application_factory(settings: Settings) -> FastAPI:
         prefix_format=f"/{API_PREFIX}/v{{major}}",
         default_version=(1, 0),
         enable_latest=False,
+        middleware=middleware,
         kwargs=fastapi_settings,
     )
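standalone/app.py now registers a ClientErrorPayloadLoggerMiddleware and hands the whole middleware stack to the versioned application rather than to the inner FastAPI kwargs. The middleware itself lives in nucliadb/middleware/__init__.py (+82 lines in the file list above) and is not shown in this diff; the following is only a hypothetical sketch of what a "log the request payload on 4xx responses" ASGI middleware can look like, with the class internals, field names and log format invented for illustration:

import logging

from starlette.types import ASGIApp, Message, Receive, Scope, Send

logger = logging.getLogger(__name__)


class ClientErrorPayloadLoggerMiddleware:
    # Hypothetical sketch: capture the request body as it streams in and,
    # if the response turns out to be a 4xx, log a truncated copy of it.
    def __init__(self, app: ASGIApp, max_logged_bytes: int = 2048) -> None:
        self.app = app
        self.max_logged_bytes = max_logged_bytes

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        body_chunks: list[bytes] = []
        status_code = 0

        async def receive_wrapper() -> Message:
            message = await receive()
            if message["type"] == "http.request":
                body_chunks.append(message.get("body", b""))
            return message

        async def send_wrapper(message: Message) -> None:
            nonlocal status_code
            if message["type"] == "http.response.start":
                status_code = message["status"]
            await send(message)

        await self.app(scope, receive_wrapper, send_wrapper)

        if 400 <= status_code < 500:
            payload = b"".join(body_chunks)[: self.max_logged_bytes]
            logger.info(
                "Client error response",
                extra={"path": scope.get("path"), "status": status_code, "payload": payload},
            )

Registered through Middleware(ClientErrorPayloadLoggerMiddleware), Starlette instantiates the class with the downstream app as its first argument, which is why no extra arguments appear in the diff above.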
nucliadb/standalone/auth.py CHANGED
@@ -19,7 +19,6 @@
 import base64
 import logging
 import time
-from typing import Optional

 import orjson
 from jwcrypto import jwe, jwk  # type: ignore
@@ -51,7 +50,7 @@ def get_mapped_roles(*, settings: Settings, data: dict[str, str]) -> list[str]:

 async def authenticate_auth_token(
     settings: Settings, request: HTTPConnection
-) -> Optional[tuple[AuthCredentials, BaseUser]]:
+) -> tuple[AuthCredentials, BaseUser] | None:
     if "eph-token" not in request.query_params or settings.jwk_key is None:
         return None

@@ -81,7 +80,7 @@ class AuthHeaderAuthenticationBackend(NucliaCloudAuthenticationBackend):
     def __init__(self, settings: Settings) -> None:
         self.settings = settings

-    async def authenticate(self, request: HTTPConnection) -> Optional[tuple[AuthCredentials, BaseUser]]:
+    async def authenticate(self, request: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
         token_resp = await authenticate_auth_token(self.settings, request)
         if token_resp is not None:
             return token_resp
@@ -109,7 +108,7 @@ class OAuth2AuthenticationBackend(NucliaCloudAuthenticationBackend):
     def __init__(self, settings: Settings) -> None:
         self.settings = settings

-    async def authenticate(self, request: HTTPConnection) -> Optional[tuple[AuthCredentials, BaseUser]]:
+    async def authenticate(self, request: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
         token_resp = await authenticate_auth_token(self.settings, request)
         if token_resp is not None:
             return token_resp
@@ -160,7 +159,7 @@ class BasicAuthAuthenticationBackend(NucliaCloudAuthenticationBackend):
     def __init__(self, settings: Settings) -> None:
         self.settings = settings

-    async def authenticate(self, request: HTTPConnection) -> Optional[tuple[AuthCredentials, BaseUser]]:
+    async def authenticate(self, request: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
         token_resp = await authenticate_auth_token(self.settings, request)
         if token_resp is not None:
             return token_resp
@@ -189,7 +188,7 @@ class UpstreamNaiveAuthenticationBackend(NucliaCloudAuthenticationBackend):
             user_header=settings.auth_policy_user_header,
         )

-    async def authenticate(self, request: HTTPConnection) -> Optional[tuple[AuthCredentials, BaseUser]]:
+    async def authenticate(self, request: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
         token_resp = await authenticate_auth_token(self.settings, request)
         if token_resp is not None:
             return token_resp
nucliadb/standalone/lifecycle.py CHANGED
@@ -17,7 +17,7 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
-import
+import inspect
 from contextlib import asynccontextmanager

 from fastapi import FastAPI
@@ -56,7 +56,7 @@ async def lifespan(app: FastAPI):
     yield

     for finalizer in SYNC_FINALIZERS:
-        if
+        if inspect.iscoroutinefunction(finalizer):
             await finalizer()
         else:
             finalizer()
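The lifespan shutdown hook now dispatches on inspect.iscoroutinefunction, so SYNC_FINALIZERS can hold a mix of plain and async callables. A self-contained sketch of that dispatch; the registry name and the two example finalizers are made up for illustration:

import asyncio
import inspect
from collections.abc import Callable

FINALIZERS: list[Callable] = []


async def run_finalizers() -> None:
    # Await coroutine functions, call plain functions synchronously.
    for finalizer in FINALIZERS:
        if inspect.iscoroutinefunction(finalizer):
            await finalizer()
        else:
            finalizer()


async def main() -> None:
    async def close_pool() -> None:  # e.g. an async connection-pool shutdown
        print("async finalizer ran")

    def flush_metrics() -> None:  # e.g. a synchronous metrics flush
        print("sync finalizer ran")

    FINALIZERS.extend([close_pool, flush_metrics])
    await run_finalizers()


asyncio.run(main())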
nucliadb/standalone/run.py CHANGED
@@ -21,7 +21,6 @@ import asyncio
 import logging
 import os
 import sys
-from typing import Optional

 import argdantic
 import uvicorn  # type: ignore
@@ -116,6 +115,9 @@ def run():
     if nuclia_settings.nuclia_service_account:
         settings_to_output["NUA API key"] = "Configured ✔"
         settings_to_output["NUA API zone"] = nuclia_settings.nuclia_zone
+        settings_to_output["NUA API url"] = (
+            nuclia_settings.nuclia_public_url.format(zone=nuclia_settings.nuclia_zone) + "/api"
+        )

     settings_to_output_fmted = "\n".join(
         [f"|| - {k}:{' ' * (27 - len(k))}{v}" for k, v in settings_to_output.items()]
@@ -145,9 +147,8 @@
     server.run()


-def get_latest_nucliadb() -> Optional[str]:
-
-    return loop.run_until_complete(versions.latest_nucliadb())
+def get_latest_nucliadb() -> str | None:
+    return asyncio.run(versions.latest_nucliadb())


 async def run_async_nucliadb(settings: Settings) -> uvicorn.Server:
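run.py also drops the hand-driven event loop for the update check in favour of asyncio.run. An isolated sketch of the same change; the version lookup is stubbed out here:

import asyncio


async def latest_nucliadb() -> str | None:
    # Stub standing in for the real "latest published version" lookup
    return "6.10.0"


# Old pattern: obtain a loop and drive it manually
# loop = asyncio.get_event_loop()
# latest = loop.run_until_complete(latest_nucliadb())

# New pattern: asyncio.run creates, runs and closes a fresh loop
latest = asyncio.run(latest_nucliadb())
print(latest)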
nucliadb/standalone/settings.py CHANGED
@@ -18,7 +18,6 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 from enum import Enum
-from typing import Optional

 import pydantic

@@ -44,11 +43,11 @@ class Settings(DriverSettings, StorageSettings, ExtendedStorageSettings):
     # all settings here are mapped in to other env var settings used
     # in the app. These are helper settings to make things easier to
     # use with standalone app vs cluster app.
-    nua_api_key: Optional[str] = pydantic.Field(
+    nua_api_key: str | None = pydantic.Field(
         default=None,
-        description="Nuclia Understanding API Key. Read how to generate a NUA Key here: https://docs.nuclia.dev/docs/rag/advanced/understanding/intro#get-a-nua-key",
+        description="Nuclia Understanding API Key. Read how to generate a NUA Key here: https://docs.nuclia.dev/docs/rag/advanced/understanding/intro#get-a-nua-key",
     )
-    zone: Optional[str] = pydantic.Field(default=None, description="Nuclia Understanding API Zone ID")
+    zone: str | None = pydantic.Field(default=None, description="Nuclia Understanding API Zone ID")
     http_host: str = pydantic.Field(default="0.0.0.0", description="HTTP Port")
     http_port: int = pydantic.Field(default=8080, description="HTTP Port")
     ingest_grpc_port: int = pydantic.Field(default=8030, description="Ingest GRPC Port")
@@ -83,7 +82,7 @@ class Settings(DriverSettings, StorageSettings, ExtendedStorageSettings):
         description="Default role to assign to user that is authenticated \
             upstream. Not used with `upstream_naive` auth policy.",
     )
-    auth_policy_role_mapping: Optional[dict[str, dict[str, list[NucliaDBRoles]]]] = pydantic.Field(
+    auth_policy_role_mapping: dict[str, dict[str, list[NucliaDBRoles]]] | None = pydantic.Field(
         default=None,
         description="""
 Role mapping for `upstream_auth_header`, `upstream_oauth2` and `upstream_basicauth` auth policies.
@@ -97,7 +96,7 @@ Examples:
 """,
     )

-    jwk_key: Optional[str] = pydantic.Field(
+    jwk_key: str | None = pydantic.Field(
         default=None,
         description="JWK key used for temporary token generation and validation.",
     )
nucliadb/standalone/versions.py CHANGED
@@ -20,7 +20,6 @@
 import enum
 import importlib.metadata
 import logging
-from typing import Optional

 from cachetools import TTLCache

@@ -45,11 +44,11 @@ def installed_nucliadb() -> str:
     return get_installed_version(StandalonePackages.NUCLIADB.value)


-async def latest_nucliadb() -> Optional[str]:
+async def latest_nucliadb() -> str | None:
     return await get_latest_version(StandalonePackages.NUCLIADB.value)


-def nucliadb_updates_available(installed: str, latest: Optional[str]) -> bool:
+def nucliadb_updates_available(installed: str, latest: str | None) -> bool:
     if latest is None:
         return False
     return is_newer_release(installed, latest)
@@ -96,7 +95,7 @@ def get_installed_version(package_name: str) -> str:
     return importlib.metadata.distribution(package_name).version


-async def get_latest_version(package: str) -> Optional[str]:
+async def get_latest_version(package: str) -> str | None:
     result = CACHE.get(package, None)
     if result is None:
         try:
nucliadb/tasks/consumer.py CHANGED
@@ -19,9 +19,10 @@
 #

 import asyncio
-from typing import Generic
+from typing import Generic

 import nats
+import nats.js.api
 import pydantic
 from nats.aio.client import Msg

@@ -43,8 +44,9 @@ class NatsTaskConsumer(Generic[MsgType]):
         stream: NatsStream,
         consumer: NatsConsumer,
         callback: Callback,
-        msg_type:
-        max_concurrent_messages: Optional[int] = None,
+        msg_type: type[MsgType],
+        max_concurrent_messages: int | None = None,
+        max_deliver: int | None = None,
     ):
         self.name = name
         self.stream = stream
@@ -52,6 +54,7 @@ class NatsTaskConsumer(Generic[MsgType]):
         self.callback = callback
         self.msg_type = msg_type
         self.max_concurrent_messages = max_concurrent_messages
+        self.max_deliver = max_deliver
         self.initialized = False
         self.running_tasks: list[asyncio.Task] = []
         self.subscription = None
@@ -71,7 +74,8 @@ class NatsTaskConsumer(Generic[MsgType]):
             for task in self.running_tasks:
                 task.cancel()
             try:
-
+                if len(self.running_tasks) > 0:
+                    await asyncio.wait(self.running_tasks, timeout=5)
                 self.running_tasks.clear()
             except asyncio.TimeoutError:
                 pass
@@ -96,6 +100,7 @@ class NatsTaskConsumer(Generic[MsgType]):
                 ack_wait=nats_consumer_settings.nats_ack_wait,
                 idle_heartbeat=nats_consumer_settings.nats_idle_heartbeat,
                 max_ack_pending=max_ack_pending,
+                max_deliver=self.max_deliver,
             ),
         )
         logger.info(
@@ -168,8 +173,6 @@ class NatsTaskConsumer(Generic[MsgType]):
                 },
             )
             await msg.ack()
-        finally:
-            return


 def create_consumer(
@@ -177,8 +180,9 @@ def create_consumer(
     stream: NatsStream,
     consumer: NatsConsumer,
     callback: Callback,
-    msg_type:
-    max_concurrent_messages: Optional[int] = None,
+    msg_type: type[MsgType],
+    max_concurrent_messages: int | None = None,
+    max_retries: int = 100,
 ) -> NatsTaskConsumer[MsgType]:
     """
     Returns a non-initialized consumer
@@ -190,4 +194,5 @@
         callback=callback,
         msg_type=msg_type,
         max_concurrent_messages=max_concurrent_messages,
+        max_deliver=max_retries,
     )
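tasks/consumer.py threads a new max_deliver option (exposed as max_retries on create_consumer) through to the JetStream consumer configuration, so a message that keeps failing is redelivered at most that many times instead of forever. A hedged sketch of the same knob with the nats-py client; the subject, stream name and handler are invented, and only ConsumerConfig and its fields mirror the diff:

import nats
from nats.aio.msg import Msg
from nats.js.api import ConsumerConfig


async def subscribe_with_retry_limit() -> None:
    nc = await nats.connect("nats://localhost:4222")
    js = nc.jetstream()

    async def handler(msg: Msg) -> None:
        try:
            ...  # process the task payload
            await msg.ack()
        except Exception:
            await msg.nak()  # redelivered until max_deliver is exhausted

    await js.subscribe(
        "tasks.example",          # subject (made up for this sketch)
        stream="tasks",           # stream name (made up)
        cb=handler,
        manual_ack=True,
        config=ConsumerConfig(
            ack_wait=60,          # seconds before an unacked message is redelivered
            max_ack_pending=10,   # in-flight, unacked messages allowed
            max_deliver=100,      # hard cap on delivery attempts per message
        ),
    )

# asyncio.run(subscribe_with_retry_limit())  # requires a running NATS server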
nucliadb/tasks/models.py CHANGED
@@ -17,7 +17,8 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
-from
+from collections.abc import Callable, Coroutine
+from typing import Any, TypeVar

 import pydantic
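tasks/models.py (and several of the files that follow) also move Callable, Coroutine and the async-generator ABCs from typing to collections.abc, their preferred home since the typing aliases were deprecated in Python 3.9. A minimal sketch of the split; the Callback alias is illustrative, not the package's actual definition:

from collections.abc import Callable, Coroutine
from typing import Any, TypeVar

MsgType = TypeVar("MsgType")

# Illustrative alias: an async callback taking any arguments and returning None
Callback = Callable[..., Coroutine[Any, Any, None]]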
nucliadb/tasks/producer.py CHANGED
@@ -17,7 +17,7 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
-from typing import Generic
+from typing import Generic

 from nucliadb.tasks.logger import logger
 from nucliadb.tasks.models import MsgType
@@ -32,7 +32,7 @@ class NatsTaskProducer(Generic[MsgType]):
         name: str,
         stream: NatsStream,
         producer_subject: str,
-        msg_type:
+        msg_type: type[MsgType],
     ):
         self.name = name
         self.stream = stream
@@ -69,7 +69,7 @@ def create_producer(
     name: str,
     stream: NatsStream,
     producer_subject: str,
-    msg_type:
+    msg_type: type[MsgType],
 ) -> NatsTaskProducer[MsgType]:
     """
     Returns a non-initialized producer.
nucliadb/tasks/retries.py CHANGED
@@ -19,9 +19,10 @@
 #
 import functools
 import logging
+from collections.abc import Callable
 from datetime import datetime, timezone
 from enum import Enum
-from typing import
+from typing import cast

 from pydantic import BaseModel

@@ -44,7 +45,7 @@ class TaskMetadata(BaseModel):
     status: Status
     retries: int = 0
     error_messages: list[str] = []
-    last_modified: Optional[datetime] = None
+    last_modified: datetime | None = None


 class TaskRetryHandler:
@@ -87,7 +88,7 @@ class TaskRetryHandler:
             kbid=self.kbid, task_type=self.task_type, task_id=self.task_id
         )

-    async def get_metadata(self) -> Optional[TaskMetadata]:
+    async def get_metadata(self) -> TaskMetadata | None:
         return await _get_metadata(self.context.kv_driver, self.metadata_key)

     async def set_metadata(self, metadata: TaskMetadata) -> None:
@@ -150,7 +151,7 @@ class TaskRetryHandler:
         return wrapper


-async def _get_metadata(kv_driver: Driver, metadata_key: str) -> Optional[TaskMetadata]:
+async def _get_metadata(kv_driver: Driver, metadata_key: str) -> TaskMetadata | None:
     async with kv_driver.ro_transaction() as txn:
         metadata = await txn.get(metadata_key)
         if metadata is None:
@@ -173,7 +174,7 @@ async def purge_metadata(kv_driver: Driver) -> int:
         return 0

     total_purged = 0
-    start: Optional[str] = ""
+    start: str | None = ""
     while True:
         start, purged = await purge_batch(kv_driver, start)
         total_purged += purged
@@ -183,8 +184,8 @@ async def purge_metadata(kv_driver: Driver) -> int:


 async def purge_batch(
-    kv_driver: PGDriver, start:
-) -> tuple[
+    kv_driver: PGDriver, start: str | None = None, batch_size: int = 200
+) -> tuple[str | None, int]:
     """
     Returns the next start key and the number of purged records. If start is None, it means there are no more records to purge.
     """
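The purge helpers in tasks/retries.py page through task metadata with a keyset cursor: purge_batch takes the previous cursor and returns the next one plus the purged count, and a None cursor signals completion. A self-contained sketch of that loop shape, with an in-memory key list standing in for the real PGDriver storage:

import asyncio


async def purge_batch(keys: list[str], start: str | None, batch_size: int = 200) -> tuple[str | None, int]:
    # Delete up to batch_size keys sorting after `start`; return (next cursor, purged count).
    batch = sorted(k for k in keys if start is None or k > start)[:batch_size]
    for key in batch:
        keys.remove(key)
    next_start = batch[-1] if len(batch) == batch_size else None
    return next_start, len(batch)


async def purge_all(keys: list[str]) -> int:
    total_purged = 0
    start: str | None = ""
    while True:
        start, purged = await purge_batch(keys, start)
        total_purged += purged
        if start is None:
            return total_purged


print(asyncio.run(purge_all([f"task/{i:03d}" for i in range(450)])))  # prints 450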
nucliadb/train/api/utils.py CHANGED
@@ -19,12 +19,10 @@
 #


-from typing import Optional
-
 from nucliadb.train.utils import get_shard_manager


-async def get_kb_partitions(kbid: str, prefix: Optional[str] = None) -> list[str]:
+async def get_kb_partitions(kbid: str, prefix: str | None = None) -> list[str]:
     shard_manager = get_shard_manager()
     shards = await shard_manager.get_shards_by_kbid_inner(kbid=kbid)
     valid_shards = []
nucliadb/train/api/v1/shards.py CHANGED
@@ -18,7 +18,6 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 import json
-from typing import Optional

 import google.protobuf.message
 import pydantic
@@ -63,7 +62,7 @@ async def object_get_response(
     )


-async def get_trainset(request: Request) -> tuple[TrainSet, Optional[FilterExpression]]:
+async def get_trainset(request: Request) -> tuple[TrainSet, FilterExpression | None]:
     if request.headers.get("Content-Type") == "application/json":
         try:
             trainset_model = TrainSetModel.model_validate(await request.json())
nucliadb/train/api/v1/trainset.py CHANGED
@@ -18,7 +18,6 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #

-from typing import Optional

 from fastapi import HTTPException, Request
 from fastapi_versioning import version
@@ -57,7 +56,7 @@ async def get_partitions_prefix(request: Request, kbid: str, prefix: str) -> Tra
     return await get_partitions(kbid, prefix=prefix)


-async def get_partitions(kbid: str, prefix: Optional[str] = None) -> TrainSetPartitions:
+async def get_partitions(kbid: str, prefix: str | None = None) -> TrainSetPartitions:
     try:
         all_keys = await get_kb_partitions(kbid, prefix)
     except ShardNotFound:
nucliadb/train/app.py CHANGED
@@ -50,7 +50,6 @@ errors.setup_error_handling(importlib.metadata.distribution("nucliadb").version)

 fastapi_settings = dict(
     debug=running_settings.debug,
-    middleware=middleware,
     lifespan=lifespan,
     exception_handlers={
         Exception: global_exception_handler,
@@ -71,6 +70,7 @@ application = VersionedFastAPI(
     prefix_format=f"/{API_PREFIX}/v{{major}}",
     default_version=(1, 0),
     enable_latest=False,
+    middleware=middleware,
     kwargs=fastapi_settings,
 )
nucliadb/train/generator.py CHANGED
@@ -17,7 +17,7 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
-from
+from collections.abc import AsyncIterator, Callable

 from fastapi import HTTPException
 from grpc import StatusCode
@@ -53,11 +53,11 @@ from nucliadb.train.utils import get_shard_manager
 from nucliadb_models.filters import FilterExpression
 from nucliadb_protos.dataset_pb2 import TaskType, TrainSet

-BatchGenerator = Callable[[str, TrainSet, str, Optional[FilterExpression]], AsyncIterator[TrainBatch]]
+BatchGenerator = Callable[[str, TrainSet, str, FilterExpression | None], AsyncIterator[TrainBatch]]


 async def generate_train_data(
-    kbid: str, shard: str, trainset: TrainSet, filter_expression: Optional[FilterExpression] = None
+    kbid: str, shard: str, trainset: TrainSet, filter_expression: FilterExpression | None = None
 ):
     # Get the data structure to generate data
     shard_manager = get_shard_manager()
@@ -66,7 +66,7 @@ async def generate_train_data(
     if trainset.batch_size == 0:
         trainset.batch_size = 50

-    batch_generator: Optional[BatchGenerator] = None
+    batch_generator: BatchGenerator | None = None

     if trainset.type == TaskType.FIELD_CLASSIFICATION:
         batch_generator = field_classification_batch_generator
nucliadb/train/generators/field_classifier.py CHANGED
@@ -18,7 +18,7 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #

-from
+from collections.abc import AsyncGenerator

 from nidx_protos.nodereader_pb2 import StreamRequest

@@ -39,7 +39,7 @@ def field_classification_batch_generator(
     kbid: str,
     trainset: TrainSet,
     shard_replica_id: str,
-    filter_expression: Optional[FilterExpression],
+    filter_expression: FilterExpression | None,
 ) -> AsyncGenerator[FieldClassificationBatch, None]:
     generator = generate_field_classification_payloads(kbid, trainset, shard_replica_id)
     batch_generator = batchify(generator, trainset.batch_size, FieldClassificationBatch)
nucliadb/train/generators/field_streaming.py CHANGED
@@ -19,7 +19,7 @@
 #

 import asyncio
-from
+from collections.abc import AsyncGenerator, AsyncIterable

 from nidx_protos.nodereader_pb2 import DocumentItem, StreamRequest

@@ -45,7 +45,7 @@ def field_streaming_batch_generator(
     kbid: str,
     trainset: TrainSet,
     shard_replica_id: str,
-    filter_expression: Optional[FilterExpression],
+    filter_expression: FilterExpression | None,
 ) -> AsyncGenerator[FieldStreamingBatch, None]:
     generator = generate_field_streaming_payloads(kbid, trainset, shard_replica_id, filter_expression)
     batch_generator = batchify(generator, trainset.batch_size, FieldStreamingBatch)
@@ -53,7 +53,7 @@ def field_streaming_batch_generator(


 async def generate_field_streaming_payloads(
-    kbid: str, trainset: TrainSet, shard_replica_id: str, filter_expression: Optional[FilterExpression]
+    kbid: str, trainset: TrainSet, shard_replica_id: str, filter_expression: FilterExpression | None
 ) -> AsyncGenerator[FieldSplitData, None]:
     request = StreamRequest()
     request.shard_id.id = shard_replica_id
@@ -192,7 +192,7 @@ async def _fetch_basic(kbid: str, fsd: FieldSplitData):
         fsd.basic.CopyFrom(basic)


-async def get_field_text(kbid: str, rid: str, field: str, field_type: str) -> Optional[ExtractedText]:
+async def get_field_text(kbid: str, rid: str, field: str, field_type: str) -> ExtractedText | None:
     orm_resource = await get_resource_from_cache_or_db(kbid, rid)

     if orm_resource is None:
@@ -208,7 +208,7 @@ async def get_field_text(kbid: str, rid: str, field: str, field_type: str) -> Op

 async def get_field_metadata(
     kbid: str, rid: str, field: str, field_type: str
-) -> Optional[FieldComputedMetadata]:
+) -> FieldComputedMetadata | None:
     orm_resource = await get_resource_from_cache_or_db(kbid, rid)

     if orm_resource is None:
@@ -222,7 +222,7 @@ async def get_field_metadata(
     return field_metadata


-async def get_field_basic(kbid: str, rid: str, field: str, field_type: str) -> Optional[Basic]:
+async def get_field_basic(kbid: str, rid: str, field: str, field_type: str) -> Basic | None:
     orm_resource = await get_resource_from_cache_or_db(kbid, rid)

     if orm_resource is None: