nucliadb 6.9.1.post5192__py3-none-any.whl → 6.10.0.post5705__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- migrations/0023_backfill_pg_catalog.py +2 -2
- migrations/0029_backfill_field_status.py +3 -4
- migrations/0032_remove_old_relations.py +2 -3
- migrations/0038_backfill_catalog_field_labels.py +2 -2
- migrations/0039_backfill_converation_splits_metadata.py +2 -2
- migrations/0041_reindex_conversations.py +137 -0
- migrations/pg/0010_shards_index.py +34 -0
- nucliadb/search/api/v1/resource/utils.py → migrations/pg/0011_catalog_statistics.py +5 -6
- migrations/pg/0012_catalog_statistics_undo.py +26 -0
- nucliadb/backups/create.py +2 -15
- nucliadb/backups/restore.py +4 -15
- nucliadb/backups/tasks.py +4 -1
- nucliadb/common/back_pressure/cache.py +2 -3
- nucliadb/common/back_pressure/materializer.py +7 -13
- nucliadb/common/back_pressure/settings.py +6 -6
- nucliadb/common/back_pressure/utils.py +1 -0
- nucliadb/common/cache.py +9 -9
- nucliadb/common/catalog/interface.py +12 -12
- nucliadb/common/catalog/pg.py +41 -29
- nucliadb/common/catalog/utils.py +3 -3
- nucliadb/common/cluster/manager.py +5 -4
- nucliadb/common/cluster/rebalance.py +483 -114
- nucliadb/common/cluster/rollover.py +25 -9
- nucliadb/common/cluster/settings.py +3 -8
- nucliadb/common/cluster/utils.py +34 -8
- nucliadb/common/context/__init__.py +7 -8
- nucliadb/common/context/fastapi.py +1 -2
- nucliadb/common/datamanagers/__init__.py +2 -4
- nucliadb/common/datamanagers/atomic.py +4 -2
- nucliadb/common/datamanagers/cluster.py +1 -2
- nucliadb/common/datamanagers/fields.py +3 -4
- nucliadb/common/datamanagers/kb.py +6 -6
- nucliadb/common/datamanagers/labels.py +2 -3
- nucliadb/common/datamanagers/resources.py +10 -33
- nucliadb/common/datamanagers/rollover.py +5 -7
- nucliadb/common/datamanagers/search_configurations.py +1 -2
- nucliadb/common/datamanagers/synonyms.py +1 -2
- nucliadb/common/datamanagers/utils.py +4 -4
- nucliadb/common/datamanagers/vectorsets.py +4 -4
- nucliadb/common/external_index_providers/base.py +32 -5
- nucliadb/common/external_index_providers/manager.py +4 -5
- nucliadb/common/filter_expression.py +128 -40
- nucliadb/common/http_clients/processing.py +12 -23
- nucliadb/common/ids.py +6 -4
- nucliadb/common/locking.py +1 -2
- nucliadb/common/maindb/driver.py +9 -8
- nucliadb/common/maindb/local.py +5 -5
- nucliadb/common/maindb/pg.py +9 -8
- nucliadb/common/nidx.py +3 -4
- nucliadb/export_import/datamanager.py +4 -3
- nucliadb/export_import/exporter.py +11 -19
- nucliadb/export_import/importer.py +13 -6
- nucliadb/export_import/tasks.py +2 -0
- nucliadb/export_import/utils.py +6 -18
- nucliadb/health.py +2 -2
- nucliadb/ingest/app.py +8 -8
- nucliadb/ingest/consumer/consumer.py +8 -10
- nucliadb/ingest/consumer/pull.py +3 -8
- nucliadb/ingest/consumer/service.py +3 -3
- nucliadb/ingest/consumer/utils.py +1 -1
- nucliadb/ingest/fields/base.py +28 -49
- nucliadb/ingest/fields/conversation.py +12 -12
- nucliadb/ingest/fields/exceptions.py +1 -2
- nucliadb/ingest/fields/file.py +22 -8
- nucliadb/ingest/fields/link.py +7 -7
- nucliadb/ingest/fields/text.py +2 -3
- nucliadb/ingest/orm/brain_v2.py +78 -64
- nucliadb/ingest/orm/broker_message.py +2 -4
- nucliadb/ingest/orm/entities.py +10 -209
- nucliadb/ingest/orm/index_message.py +4 -4
- nucliadb/ingest/orm/knowledgebox.py +18 -27
- nucliadb/ingest/orm/processor/auditing.py +1 -3
- nucliadb/ingest/orm/processor/data_augmentation.py +1 -2
- nucliadb/ingest/orm/processor/processor.py +27 -27
- nucliadb/ingest/orm/processor/sequence_manager.py +1 -2
- nucliadb/ingest/orm/resource.py +72 -70
- nucliadb/ingest/orm/utils.py +1 -1
- nucliadb/ingest/processing.py +17 -17
- nucliadb/ingest/serialize.py +202 -145
- nucliadb/ingest/service/writer.py +3 -109
- nucliadb/ingest/settings.py +3 -4
- nucliadb/ingest/utils.py +1 -2
- nucliadb/learning_proxy.py +11 -11
- nucliadb/metrics_exporter.py +5 -4
- nucliadb/middleware/__init__.py +82 -1
- nucliadb/migrator/datamanager.py +3 -4
- nucliadb/migrator/migrator.py +1 -2
- nucliadb/migrator/models.py +1 -2
- nucliadb/migrator/settings.py +1 -2
- nucliadb/models/internal/augment.py +614 -0
- nucliadb/models/internal/processing.py +19 -19
- nucliadb/openapi.py +2 -2
- nucliadb/purge/__init__.py +3 -8
- nucliadb/purge/orphan_shards.py +1 -2
- nucliadb/reader/__init__.py +5 -0
- nucliadb/reader/api/models.py +6 -13
- nucliadb/reader/api/v1/download.py +59 -38
- nucliadb/reader/api/v1/export_import.py +4 -4
- nucliadb/reader/api/v1/learning_config.py +24 -4
- nucliadb/reader/api/v1/resource.py +61 -9
- nucliadb/reader/api/v1/services.py +18 -14
- nucliadb/reader/app.py +3 -1
- nucliadb/reader/reader/notifications.py +1 -2
- nucliadb/search/api/v1/__init__.py +2 -0
- nucliadb/search/api/v1/ask.py +3 -4
- nucliadb/search/api/v1/augment.py +585 -0
- nucliadb/search/api/v1/catalog.py +11 -15
- nucliadb/search/api/v1/find.py +16 -22
- nucliadb/search/api/v1/hydrate.py +25 -25
- nucliadb/search/api/v1/knowledgebox.py +1 -2
- nucliadb/search/api/v1/predict_proxy.py +1 -2
- nucliadb/search/api/v1/resource/ask.py +7 -7
- nucliadb/search/api/v1/resource/ingestion_agents.py +5 -6
- nucliadb/search/api/v1/resource/search.py +9 -11
- nucliadb/search/api/v1/retrieve.py +130 -0
- nucliadb/search/api/v1/search.py +28 -32
- nucliadb/search/api/v1/suggest.py +11 -14
- nucliadb/search/api/v1/summarize.py +1 -2
- nucliadb/search/api/v1/utils.py +2 -2
- nucliadb/search/app.py +3 -2
- nucliadb/search/augmentor/__init__.py +21 -0
- nucliadb/search/augmentor/augmentor.py +232 -0
- nucliadb/search/augmentor/fields.py +704 -0
- nucliadb/search/augmentor/metrics.py +24 -0
- nucliadb/search/augmentor/paragraphs.py +334 -0
- nucliadb/search/augmentor/resources.py +238 -0
- nucliadb/search/augmentor/utils.py +33 -0
- nucliadb/search/lifecycle.py +3 -1
- nucliadb/search/predict.py +24 -17
- nucliadb/search/predict_models.py +8 -9
- nucliadb/search/requesters/utils.py +11 -10
- nucliadb/search/search/cache.py +19 -23
- nucliadb/search/search/chat/ask.py +88 -59
- nucliadb/search/search/chat/exceptions.py +3 -5
- nucliadb/search/search/chat/fetcher.py +201 -0
- nucliadb/search/search/chat/images.py +6 -4
- nucliadb/search/search/chat/old_prompt.py +1375 -0
- nucliadb/search/search/chat/parser.py +510 -0
- nucliadb/search/search/chat/prompt.py +563 -615
- nucliadb/search/search/chat/query.py +449 -36
- nucliadb/search/search/chat/rpc.py +85 -0
- nucliadb/search/search/fetch.py +3 -4
- nucliadb/search/search/filters.py +8 -11
- nucliadb/search/search/find.py +33 -31
- nucliadb/search/search/find_merge.py +124 -331
- nucliadb/search/search/graph_strategy.py +14 -12
- nucliadb/search/search/hydrator/__init__.py +3 -152
- nucliadb/search/search/hydrator/fields.py +92 -50
- nucliadb/search/search/hydrator/images.py +7 -7
- nucliadb/search/search/hydrator/paragraphs.py +42 -26
- nucliadb/search/search/hydrator/resources.py +20 -16
- nucliadb/search/search/ingestion_agents.py +5 -5
- nucliadb/search/search/merge.py +90 -94
- nucliadb/search/search/metrics.py +10 -9
- nucliadb/search/search/paragraphs.py +7 -9
- nucliadb/search/search/predict_proxy.py +13 -9
- nucliadb/search/search/query.py +14 -86
- nucliadb/search/search/query_parser/fetcher.py +51 -82
- nucliadb/search/search/query_parser/models.py +19 -20
- nucliadb/search/search/query_parser/old_filters.py +20 -19
- nucliadb/search/search/query_parser/parsers/ask.py +4 -5
- nucliadb/search/search/query_parser/parsers/catalog.py +5 -6
- nucliadb/search/search/query_parser/parsers/common.py +5 -6
- nucliadb/search/search/query_parser/parsers/find.py +6 -26
- nucliadb/search/search/query_parser/parsers/graph.py +13 -23
- nucliadb/search/search/query_parser/parsers/retrieve.py +207 -0
- nucliadb/search/search/query_parser/parsers/search.py +15 -53
- nucliadb/search/search/query_parser/parsers/unit_retrieval.py +8 -29
- nucliadb/search/search/rank_fusion.py +18 -13
- nucliadb/search/search/rerankers.py +5 -6
- nucliadb/search/search/retrieval.py +300 -0
- nucliadb/search/search/summarize.py +5 -6
- nucliadb/search/search/utils.py +3 -4
- nucliadb/search/settings.py +1 -2
- nucliadb/standalone/api_router.py +1 -1
- nucliadb/standalone/app.py +4 -3
- nucliadb/standalone/auth.py +5 -6
- nucliadb/standalone/lifecycle.py +2 -2
- nucliadb/standalone/run.py +2 -4
- nucliadb/standalone/settings.py +5 -6
- nucliadb/standalone/versions.py +3 -4
- nucliadb/tasks/consumer.py +13 -8
- nucliadb/tasks/models.py +2 -1
- nucliadb/tasks/producer.py +3 -3
- nucliadb/tasks/retries.py +8 -7
- nucliadb/train/api/utils.py +1 -3
- nucliadb/train/api/v1/shards.py +1 -2
- nucliadb/train/api/v1/trainset.py +1 -2
- nucliadb/train/app.py +1 -1
- nucliadb/train/generator.py +4 -4
- nucliadb/train/generators/field_classifier.py +2 -2
- nucliadb/train/generators/field_streaming.py +6 -6
- nucliadb/train/generators/image_classifier.py +2 -2
- nucliadb/train/generators/paragraph_classifier.py +2 -2
- nucliadb/train/generators/paragraph_streaming.py +2 -2
- nucliadb/train/generators/question_answer_streaming.py +2 -2
- nucliadb/train/generators/sentence_classifier.py +2 -2
- nucliadb/train/generators/token_classifier.py +3 -2
- nucliadb/train/generators/utils.py +6 -5
- nucliadb/train/nodes.py +3 -3
- nucliadb/train/resource.py +6 -8
- nucliadb/train/settings.py +3 -4
- nucliadb/train/types.py +11 -11
- nucliadb/train/upload.py +3 -2
- nucliadb/train/uploader.py +1 -2
- nucliadb/train/utils.py +1 -2
- nucliadb/writer/api/v1/export_import.py +4 -1
- nucliadb/writer/api/v1/field.py +7 -11
- nucliadb/writer/api/v1/knowledgebox.py +3 -4
- nucliadb/writer/api/v1/resource.py +9 -20
- nucliadb/writer/api/v1/services.py +10 -132
- nucliadb/writer/api/v1/upload.py +73 -72
- nucliadb/writer/app.py +8 -2
- nucliadb/writer/resource/basic.py +12 -15
- nucliadb/writer/resource/field.py +7 -5
- nucliadb/writer/resource/origin.py +7 -0
- nucliadb/writer/settings.py +2 -3
- nucliadb/writer/tus/__init__.py +2 -3
- nucliadb/writer/tus/azure.py +1 -3
- nucliadb/writer/tus/dm.py +3 -3
- nucliadb/writer/tus/exceptions.py +3 -4
- nucliadb/writer/tus/gcs.py +5 -6
- nucliadb/writer/tus/s3.py +2 -3
- nucliadb/writer/tus/storage.py +3 -3
- {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/METADATA +9 -10
- nucliadb-6.10.0.post5705.dist-info/RECORD +410 -0
- nucliadb/common/datamanagers/entities.py +0 -139
- nucliadb-6.9.1.post5192.dist-info/RECORD +0 -392
- {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/WHEEL +0 -0
- {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/entry_points.txt +0 -0
- {nucliadb-6.9.1.post5192.dist-info → nucliadb-6.10.0.post5705.dist-info}/top_level.txt +0 -0
nucliadb/standalone/auth.py
CHANGED

@@ -19,7 +19,6 @@
 import base64
 import logging
 import time
-from typing import Optional
 
 import orjson
 from jwcrypto import jwe, jwk  # type: ignore
@@ -51,7 +50,7 @@ def get_mapped_roles(*, settings: Settings, data: dict[str, str]) -> list[str]:
 
 async def authenticate_auth_token(
     settings: Settings, request: HTTPConnection
-) ->
+) -> tuple[AuthCredentials, BaseUser] | None:
     if "eph-token" not in request.query_params or settings.jwk_key is None:
         return None
 
@@ -81,7 +80,7 @@ class AuthHeaderAuthenticationBackend(NucliaCloudAuthenticationBackend):
     def __init__(self, settings: Settings) -> None:
         self.settings = settings
 
-    async def authenticate(self, request: HTTPConnection) ->
+    async def authenticate(self, request: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
         token_resp = await authenticate_auth_token(self.settings, request)
         if token_resp is not None:
             return token_resp
@@ -109,7 +108,7 @@ class OAuth2AuthenticationBackend(NucliaCloudAuthenticationBackend):
     def __init__(self, settings: Settings) -> None:
         self.settings = settings
 
-    async def authenticate(self, request: HTTPConnection) ->
+    async def authenticate(self, request: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
         token_resp = await authenticate_auth_token(self.settings, request)
         if token_resp is not None:
             return token_resp
@@ -160,7 +159,7 @@ class BasicAuthAuthenticationBackend(NucliaCloudAuthenticationBackend):
     def __init__(self, settings: Settings) -> None:
         self.settings = settings
 
-    async def authenticate(self, request: HTTPConnection) ->
+    async def authenticate(self, request: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
         token_resp = await authenticate_auth_token(self.settings, request)
         if token_resp is not None:
             return token_resp
@@ -189,7 +188,7 @@ class UpstreamNaiveAuthenticationBackend(NucliaCloudAuthenticationBackend):
             user_header=settings.auth_policy_user_header,
         )
 
-    async def authenticate(self, request: HTTPConnection) ->
+    async def authenticate(self, request: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
         token_resp = await authenticate_auth_token(self.settings, request)
         if token_resp is not None:
             return token_resp

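The auth.py changes above are a pure typing modernization: `Optional[...]` becomes a PEP 604 union (`... | None`), with no behavior change. A minimal sketch of the resulting backend shape, assuming Starlette's authentication interfaces (the class name below is illustrative, not nucliadb's real hierarchy):

    from starlette.authentication import (
        AuthCredentials,
        AuthenticationBackend,
        BaseUser,
        SimpleUser,
    )
    from starlette.requests import HTTPConnection


    class EphemeralTokenBackend(AuthenticationBackend):
        # Hypothetical, simplified backend; the real ones also validate the
        # JWE token against settings.jwk_key and map upstream roles.
        async def authenticate(
            self, conn: HTTPConnection
        ) -> tuple[AuthCredentials, BaseUser] | None:
            token = conn.query_params.get("eph-token")
            if token is None:
                return None  # fall through to the next auth mechanism
            return AuthCredentials(["authenticated"]), SimpleUser("ephemeral")
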
nucliadb/standalone/lifecycle.py
CHANGED

@@ -17,7 +17,7 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
-import
+import inspect
 from contextlib import asynccontextmanager
 
 from fastapi import FastAPI
@@ -56,7 +56,7 @@ async def lifespan(app: FastAPI):
     yield
 
     for finalizer in SYNC_FINALIZERS:
-        if
+        if inspect.iscoroutinefunction(finalizer):
             await finalizer()
         else:
             finalizer()

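The lifespan shutdown now dispatches with `inspect.iscoroutinefunction` (the truncated removed lines hide the previous spelling). A runnable sketch of the same dispatch pattern:

    import asyncio
    import inspect


    def sync_finalizer() -> None:
        print("sync cleanup")


    async def async_finalizer() -> None:
        print("async cleanup")


    async def shutdown(finalizers) -> None:
        # Await coroutine functions, call plain functions directly,
        # mirroring the SYNC_FINALIZERS loop above.
        for finalizer in finalizers:
            if inspect.iscoroutinefunction(finalizer):
                await finalizer()
            else:
                finalizer()


    asyncio.run(shutdown([sync_finalizer, async_finalizer]))
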
nucliadb/standalone/run.py
CHANGED

@@ -21,7 +21,6 @@ import asyncio
 import logging
 import os
 import sys
-from typing import Optional
 
 import argdantic
 import uvicorn  # type: ignore
@@ -148,9 +147,8 @@ def run():
     server.run()
 
 
-def get_latest_nucliadb() ->
-
-    return loop.run_until_complete(versions.latest_nucliadb())
+def get_latest_nucliadb() -> str | None:
+    return asyncio.run(versions.latest_nucliadb())
 
 
 async def run_async_nucliadb(settings: Settings) -> uvicorn.Server:

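Besides the typing change, `get_latest_nucliadb` switches from a manually managed event loop to `asyncio.run`, which creates, runs, and closes a fresh loop in one call. A minimal before/after sketch (the coroutine is a stand-in for `versions.latest_nucliadb()`):

    import asyncio


    async def latest() -> str | None:
        return "6.10.0"  # hypothetical value


    # Old style: requires a loop object and explicit cleanup.
    loop = asyncio.new_event_loop()
    print(loop.run_until_complete(latest()))
    loop.close()

    # New style: one call, loop lifecycle handled for you.
    print(asyncio.run(latest()))
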
nucliadb/standalone/settings.py
CHANGED

@@ -18,7 +18,6 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 from enum import Enum
-from typing import Optional
 
 import pydantic
 
@@ -44,11 +43,11 @@ class Settings(DriverSettings, StorageSettings, ExtendedStorageSettings):
     # all settings here are mapped in to other env var settings used
     # in the app. These are helper settings to make things easier to
     # use with standalone app vs cluster app.
-    nua_api_key:
+    nua_api_key: str | None = pydantic.Field(
         default=None,
-        description="Nuclia Understanding API Key. Read how to generate a NUA Key here: https://docs.nuclia.dev/docs/rag/advanced/understanding/intro#get-a-nua-key",
+        description="Nuclia Understanding API Key. Read how to generate a NUA Key here: https://docs.nuclia.dev/docs/rag/advanced/understanding/intro#get-a-nua-key",
     )
-    zone:
+    zone: str | None = pydantic.Field(default=None, description="Nuclia Understanding API Zone ID")
     http_host: str = pydantic.Field(default="0.0.0.0", description="HTTP Port")
     http_port: int = pydantic.Field(default=8080, description="HTTP Port")
     ingest_grpc_port: int = pydantic.Field(default=8030, description="Ingest GRPC Port")
@@ -83,7 +82,7 @@ class Settings(DriverSettings, StorageSettings, ExtendedStorageSettings):
         description="Default role to assign to user that is authenticated \
 upstream. Not used with `upstream_naive` auth policy.",
     )
-    auth_policy_role_mapping:
+    auth_policy_role_mapping: dict[str, dict[str, list[NucliaDBRoles]]] | None = pydantic.Field(
         default=None,
         description="""
 Role mapping for `upstream_auth_header`, `upstream_oauth2` and `upstream_basicauth` auth policies.
@@ -97,7 +96,7 @@ Examples:
 """,
     )
 
-    jwk_key:
+    jwk_key: str | None = pydantic.Field(
        default=None,
        description="JWK key used for temporary token generation and validation.",
    )

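The settings fields follow the same `Optional` → `| None` migration. A self-contained sketch of the pattern with plain pydantic (field names reused from above for illustration only):

    import pydantic


    class StandaloneSettings(pydantic.BaseModel):
        # An optional setting needs the `| None` annotation so that
        # pydantic accepts the None default and type checkers can narrow it.
        nua_api_key: str | None = pydantic.Field(default=None, description="NUA API key")
        zone: str | None = pydantic.Field(default=None, description="Zone ID")


    settings = StandaloneSettings()
    if settings.nua_api_key is not None:
        print("NUA configured")
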
nucliadb/standalone/versions.py
CHANGED

@@ -20,7 +20,6 @@
 import enum
 import importlib.metadata
 import logging
-from typing import Optional
 
 from cachetools import TTLCache
 
@@ -45,11 +44,11 @@ def installed_nucliadb() -> str:
     return get_installed_version(StandalonePackages.NUCLIADB.value)
 
 
-async def latest_nucliadb() ->
+async def latest_nucliadb() -> str | None:
     return await get_latest_version(StandalonePackages.NUCLIADB.value)
 
 
-def nucliadb_updates_available(installed: str, latest:
+def nucliadb_updates_available(installed: str, latest: str | None) -> bool:
     if latest is None:
         return False
     return is_newer_release(installed, latest)
@@ -96,7 +95,7 @@ def get_installed_version(package_name: str) -> str:
     return importlib.metadata.distribution(package_name).version
 
 
-async def get_latest_version(package: str) ->
+async def get_latest_version(package: str) -> str | None:
     result = CACHE.get(package, None)
     if result is None:
         try:

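With `latest_nucliadb` returning `str | None`, the `None` check in `nucliadb_updates_available` doubles as type narrowing. A sketch of the lookup-with-cache shape, assuming a TTLCache like the module's `CACHE` (the size, TTL, and fetch helper below are made up):

    import asyncio

    from cachetools import TTLCache

    CACHE: TTLCache = TTLCache(maxsize=16, ttl=30 * 60)


    async def fetch_latest_from_index(package: str) -> str | None:
        # Placeholder for the real package-index HTTP lookup.
        return None


    async def get_latest_version(package: str) -> str | None:
        result = CACHE.get(package, None)  # None on miss or TTL expiry
        if result is None:
            result = await fetch_latest_from_index(package)
            if result is not None:
                CACHE[package] = result  # later hits skip the network
        return result


    print(asyncio.run(get_latest_version("nucliadb")))
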
nucliadb/tasks/consumer.py
CHANGED

@@ -19,9 +19,10 @@
 #
 
 import asyncio
-from typing import Generic
+from typing import Generic
 
 import nats
+import nats.js.api
 import pydantic
 from nats.aio.client import Msg
 
@@ -43,8 +44,9 @@ class NatsTaskConsumer(Generic[MsgType]):
         stream: NatsStream,
         consumer: NatsConsumer,
         callback: Callback,
-        msg_type:
-        max_concurrent_messages:
+        msg_type: type[MsgType],
+        max_concurrent_messages: int | None = None,
+        max_deliver: int | None = None,
     ):
         self.name = name
         self.stream = stream
@@ -52,6 +54,7 @@ class NatsTaskConsumer(Generic[MsgType]):
         self.callback = callback
         self.msg_type = msg_type
         self.max_concurrent_messages = max_concurrent_messages
+        self.max_deliver = max_deliver
         self.initialized = False
         self.running_tasks: list[asyncio.Task] = []
         self.subscription = None
@@ -71,7 +74,8 @@ class NatsTaskConsumer(Generic[MsgType]):
         for task in self.running_tasks:
             task.cancel()
         try:
-
+            if len(self.running_tasks) > 0:
+                await asyncio.wait(self.running_tasks, timeout=5)
             self.running_tasks.clear()
         except asyncio.TimeoutError:
             pass
@@ -96,6 +100,7 @@ class NatsTaskConsumer(Generic[MsgType]):
                 ack_wait=nats_consumer_settings.nats_ack_wait,
                 idle_heartbeat=nats_consumer_settings.nats_idle_heartbeat,
                 max_ack_pending=max_ack_pending,
+                max_deliver=self.max_deliver,
             ),
         )
         logger.info(
@@ -168,8 +173,6 @@
             },
         )
         await msg.ack()
-        finally:
-            return
 
 
 def create_consumer(
@@ -177,8 +180,9 @@
     stream: NatsStream,
     consumer: NatsConsumer,
     callback: Callback,
-    msg_type:
-    max_concurrent_messages:
+    msg_type: type[MsgType],
+    max_concurrent_messages: int | None = None,
+    max_retries: int = 100,
 ) -> NatsTaskConsumer[MsgType]:
     """
     Returns a non-initialized consumer
@@ -190,4 +194,5 @@ def create_consumer(
     callback=callback,
     msg_type=msg_type,
     max_concurrent_messages=max_concurrent_messages,
+    max_deliver=max_retries,
 )

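`create_consumer` now caps JetStream redeliveries: `max_retries` is passed through as the consumer's `max_deliver`, so a poison message stops being retried after that many attempts. A sketch of the underlying nats-py wiring, assuming a local NATS server (the URL, subject, and stream names are made up):

    import nats
    import nats.js.api


    async def subscribe_with_retry_cap() -> None:
        nc = await nats.connect("nats://localhost:4222")  # assumed server
        js = nc.jetstream()
        await js.subscribe(
            "tasks.example",  # hypothetical subject
            stream="tasks",   # hypothetical stream
            config=nats.js.api.ConsumerConfig(
                max_ack_pending=1,
                max_deliver=100,  # JetStream gives up after 100 attempts
            ),
        )
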
nucliadb/tasks/models.py
CHANGED

@@ -17,7 +17,8 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
-from
+from collections.abc import Callable, Coroutine
+from typing import Any, TypeVar
 
 import pydantic
 

nucliadb/tasks/producer.py
CHANGED

@@ -17,7 +17,7 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
-from typing import Generic
+from typing import Generic
 
 from nucliadb.tasks.logger import logger
 from nucliadb.tasks.models import MsgType
@@ -32,7 +32,7 @@ class NatsTaskProducer(Generic[MsgType]):
         name: str,
         stream: NatsStream,
         producer_subject: str,
-        msg_type:
+        msg_type: type[MsgType],
     ):
         self.name = name
         self.stream = stream
@@ -69,7 +69,7 @@ def create_producer(
     name: str,
     stream: NatsStream,
     producer_subject: str,
-    msg_type:
+    msg_type: type[MsgType],
 ) -> NatsTaskProducer[MsgType]:
     """
     Returns a non-initialized producer.

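Producers and consumers now annotate `msg_type` as `type[MsgType]`: the parameter is the message class itself, which keeps the producer generic while still letting it validate and serialize concrete models. A minimal sketch of the pattern (the message class is hypothetical):

    from typing import Generic, TypeVar

    import pydantic

    MsgType = TypeVar("MsgType", bound=pydantic.BaseModel)


    class Producer(Generic[MsgType]):
        def __init__(self, msg_type: type[MsgType]):
            self.msg_type = msg_type  # the class, not an instance

        def serialize(self, msg: MsgType) -> bytes:
            return msg.model_dump_json().encode()


    class BackupMessage(pydantic.BaseModel):  # hypothetical task message
        kbid: str


    producer = Producer(BackupMessage)
    print(producer.serialize(BackupMessage(kbid="kb1")))
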
nucliadb/tasks/retries.py
CHANGED

@@ -19,9 +19,10 @@
 #
 import functools
 import logging
+from collections.abc import Callable
 from datetime import datetime, timezone
 from enum import Enum
-from typing import
+from typing import cast
 
 from pydantic import BaseModel
 
@@ -44,7 +45,7 @@ class TaskMetadata(BaseModel):
     status: Status
     retries: int = 0
     error_messages: list[str] = []
-    last_modified:
+    last_modified: datetime | None = None
 
 
 class TaskRetryHandler:
@@ -87,7 +88,7 @@ class TaskRetryHandler:
             kbid=self.kbid, task_type=self.task_type, task_id=self.task_id
         )
 
-    async def get_metadata(self) ->
+    async def get_metadata(self) -> TaskMetadata | None:
         return await _get_metadata(self.context.kv_driver, self.metadata_key)
 
     async def set_metadata(self, metadata: TaskMetadata) -> None:
@@ -150,7 +151,7 @@ class TaskRetryHandler:
         return wrapper
 
 
-async def _get_metadata(kv_driver: Driver, metadata_key: str) ->
+async def _get_metadata(kv_driver: Driver, metadata_key: str) -> TaskMetadata | None:
     async with kv_driver.ro_transaction() as txn:
         metadata = await txn.get(metadata_key)
         if metadata is None:
@@ -173,7 +174,7 @@ async def purge_metadata(kv_driver: Driver) -> int:
         return 0
 
     total_purged = 0
-    start:
+    start: str | None = ""
     while True:
         start, purged = await purge_batch(kv_driver, start)
         total_purged += purged
@@ -183,8 +184,8 @@ async def purge_metadata(kv_driver: Driver) -> int:
 
 
 async def purge_batch(
-    kv_driver: PGDriver, start:
-) -> tuple[
+    kv_driver: PGDriver, start: str | None = None, batch_size: int = 200
+) -> tuple[str | None, int]:
     """
     Returns the next start key and the number of purged records. If start is None, it means there are no more records to purge.
     """

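The purge loop above is cursor-style pagination: `start == ""` means "from the beginning", and `purge_batch` returns `None` as the next cursor once nothing is left. A runnable sketch of the same contract over a hypothetical in-memory key-value store:

    import asyncio

    # Hypothetical stand-in for the KV driver's task metadata keys.
    DATA = {f"tasks/{i}": b"metadata" for i in range(5)}


    async def purge_batch(start: str | None = None, batch_size: int = 2) -> tuple[str | None, int]:
        # Delete up to batch_size keys after `start`; return the next cursor,
        # or None when there are no more records to purge.
        keys = sorted(k for k in DATA if start is None or k > start)
        page, rest = keys[:batch_size], keys[batch_size:]
        for k in page:
            del DATA[k]
        return (page[-1] if rest else None), len(page)


    async def purge_all() -> int:
        total_purged = 0
        start: str | None = ""
        while True:
            start, purged = await purge_batch(start)
            total_purged += purged
            if start is None:
                break
        return total_purged


    print(asyncio.run(purge_all()))  # -> 5
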
nucliadb/train/api/utils.py
CHANGED

@@ -19,12 +19,10 @@
 #
 
 
-from typing import Optional
-
 from nucliadb.train.utils import get_shard_manager
 
 
-async def get_kb_partitions(kbid: str, prefix:
+async def get_kb_partitions(kbid: str, prefix: str | None = None) -> list[str]:
     shard_manager = get_shard_manager()
     shards = await shard_manager.get_shards_by_kbid_inner(kbid=kbid)
     valid_shards = []

nucliadb/train/api/v1/shards.py
CHANGED

@@ -18,7 +18,6 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 import json
-from typing import Optional
 
 import google.protobuf.message
 import pydantic
@@ -63,7 +62,7 @@ async def object_get_response(
     )
 
 
-async def get_trainset(request: Request) -> tuple[TrainSet,
+async def get_trainset(request: Request) -> tuple[TrainSet, FilterExpression | None]:
     if request.headers.get("Content-Type") == "application/json":
         try:
             trainset_model = TrainSetModel.model_validate(await request.json())

nucliadb/train/api/v1/trainset.py
CHANGED

@@ -18,7 +18,6 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 
-from typing import Optional
 
 from fastapi import HTTPException, Request
 from fastapi_versioning import version
@@ -57,7 +56,7 @@ async def get_partitions_prefix(request: Request, kbid: str, prefix: str) -> Tra
     return await get_partitions(kbid, prefix=prefix)
 
 
-async def get_partitions(kbid: str, prefix:
+async def get_partitions(kbid: str, prefix: str | None = None) -> TrainSetPartitions:
     try:
         all_keys = await get_kb_partitions(kbid, prefix)
     except ShardNotFound:

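`get_trainset` negotiates on Content-Type: JSON bodies are validated through the pydantic `TrainSetModel`, anything else is treated as a serialized protobuf. A minimal sketch of the protobuf branch (the handler name is illustrative):

    from fastapi import Request
    from nucliadb_protos.dataset_pb2 import TrainSet


    async def parse_trainset_pb(request: Request) -> TrainSet:
        # Non-JSON branch of get_trainset: the raw body is a serialized
        # TrainSet protobuf. The JSON branch (not shown) goes through
        # TrainSetModel.model_validate instead.
        trainset = TrainSet()
        trainset.ParseFromString(await request.body())
        return trainset
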
nucliadb/train/app.py
CHANGED

@@ -50,7 +50,6 @@ errors.setup_error_handling(importlib.metadata.distribution("nucliadb").version)
 
 fastapi_settings = dict(
     debug=running_settings.debug,
-    middleware=middleware,
     lifespan=lifespan,
     exception_handlers={
         Exception: global_exception_handler,
@@ -71,6 +70,7 @@ application = VersionedFastAPI(
     prefix_format=f"/{API_PREFIX}/v{{major}}",
     default_version=(1, 0),
     enable_latest=False,
+    middleware=middleware,
     kwargs=fastapi_settings,
 )

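The middleware list moves from the per-version FastAPI settings onto the VersionedFastAPI wrapper, so it is applied once at the outermost application rather than re-declared inside each versioned sub-app. A sketch of the idea with plain FastAPI (names and middleware choice are illustrative):

    from fastapi import FastAPI
    from starlette.middleware import Middleware
    from starlette.middleware.gzip import GZipMiddleware

    # Middleware declared on the outer app wraps every mounted sub-app.
    middleware = [Middleware(GZipMiddleware)]
    app = FastAPI(middleware=middleware)
    sub = FastAPI()
    app.mount("/api/v1", sub)  # /api/v1 requests still pass through gzip
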
nucliadb/train/generator.py
CHANGED

@@ -17,7 +17,7 @@
 # You should have received a copy of the GNU Affero General Public License
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
-from
+from collections.abc import AsyncIterator, Callable
 
 from fastapi import HTTPException
 from grpc import StatusCode
@@ -53,11 +53,11 @@ from nucliadb.train.utils import get_shard_manager
 from nucliadb_models.filters import FilterExpression
 from nucliadb_protos.dataset_pb2 import TaskType, TrainSet
 
-BatchGenerator = Callable[[str, TrainSet, str,
+BatchGenerator = Callable[[str, TrainSet, str, FilterExpression | None], AsyncIterator[TrainBatch]]
 
 
 async def generate_train_data(
-    kbid: str, shard: str, trainset: TrainSet, filter_expression:
+    kbid: str, shard: str, trainset: TrainSet, filter_expression: FilterExpression | None = None
 ):
     # Get the data structure to generate data
     shard_manager = get_shard_manager()
@@ -66,7 +66,7 @@ async def generate_train_data(
     if trainset.batch_size == 0:
         trainset.batch_size = 50
 
-    batch_generator:
+    batch_generator: BatchGenerator | None = None
 
     if trainset.type == TaskType.FIELD_CLASSIFICATION:
         batch_generator = field_classification_batch_generator

nucliadb/train/generators/field_classifier.py
CHANGED

@@ -18,7 +18,7 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 
-from
+from collections.abc import AsyncGenerator
 
 from nidx_protos.nodereader_pb2 import StreamRequest
 
@@ -39,7 +39,7 @@ def field_classification_batch_generator(
     kbid: str,
     trainset: TrainSet,
     shard_replica_id: str,
-    filter_expression:
+    filter_expression: FilterExpression | None,
 ) -> AsyncGenerator[FieldClassificationBatch, None]:
     generator = generate_field_classification_payloads(kbid, trainset, shard_replica_id)
     batch_generator = batchify(generator, trainset.batch_size, FieldClassificationBatch)

nucliadb/train/generators/field_streaming.py
CHANGED

@@ -19,7 +19,7 @@
 #
 
 import asyncio
-from
+from collections.abc import AsyncGenerator, AsyncIterable
 
 from nidx_protos.nodereader_pb2 import DocumentItem, StreamRequest
 
@@ -45,7 +45,7 @@ def field_streaming_batch_generator(
     kbid: str,
     trainset: TrainSet,
     shard_replica_id: str,
-    filter_expression:
+    filter_expression: FilterExpression | None,
 ) -> AsyncGenerator[FieldStreamingBatch, None]:
     generator = generate_field_streaming_payloads(kbid, trainset, shard_replica_id, filter_expression)
     batch_generator = batchify(generator, trainset.batch_size, FieldStreamingBatch)
@@ -53,7 +53,7 @@ def field_streaming_batch_generator(
 
 
 async def generate_field_streaming_payloads(
-    kbid: str, trainset: TrainSet, shard_replica_id: str, filter_expression:
+    kbid: str, trainset: TrainSet, shard_replica_id: str, filter_expression: FilterExpression | None
 ) -> AsyncGenerator[FieldSplitData, None]:
     request = StreamRequest()
     request.shard_id.id = shard_replica_id
@@ -192,7 +192,7 @@ async def _fetch_basic(kbid: str, fsd: FieldSplitData):
     fsd.basic.CopyFrom(basic)
 
 
-async def get_field_text(kbid: str, rid: str, field: str, field_type: str) ->
+async def get_field_text(kbid: str, rid: str, field: str, field_type: str) -> ExtractedText | None:
     orm_resource = await get_resource_from_cache_or_db(kbid, rid)
 
     if orm_resource is None:
@@ -208,7 +208,7 @@ async def get_field_text(kbid: str, rid: str, field: str, field_type: str) -> Op
 
 async def get_field_metadata(
     kbid: str, rid: str, field: str, field_type: str
-) ->
+) -> FieldComputedMetadata | None:
     orm_resource = await get_resource_from_cache_or_db(kbid, rid)
 
     if orm_resource is None:
@@ -222,7 +222,7 @@ async def get_field_metadata(
     return field_metadata
 
 
-async def get_field_basic(kbid: str, rid: str, field: str, field_type: str) ->
+async def get_field_basic(kbid: str, rid: str, field: str, field_type: str) -> Basic | None:
     orm_resource = await get_resource_from_cache_or_db(kbid, rid)
 
     if orm_resource is None:

nucliadb/train/generators/image_classifier.py
CHANGED

@@ -18,7 +18,7 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 
-from
+from collections.abc import AsyncGenerator
 
 from nucliadb.train.generators.utils import batchify
 from nucliadb_models.filters import FilterExpression
@@ -33,7 +33,7 @@ def image_classification_batch_generator(
     kbid: str,
     trainset: TrainSet,
     shard_replica_id: str,
-    filter_expression:
+    filter_expression: FilterExpression | None,
 ) -> AsyncGenerator[ImageClassificationBatch, None]:
     generator = generate_image_classification_payloads(kbid, trainset, shard_replica_id)
     batch_generator = batchify(generator, trainset.batch_size, ImageClassificationBatch)

nucliadb/train/generators/paragraph_classifier.py
CHANGED

@@ -18,7 +18,7 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 
-from
+from collections.abc import AsyncGenerator
 
 from fastapi import HTTPException
 from nidx_protos.nodereader_pb2 import StreamRequest
@@ -38,7 +38,7 @@ def paragraph_classification_batch_generator(
     kbid: str,
     trainset: TrainSet,
     shard_replica_id: str,
-    filter_expression:
+    filter_expression: FilterExpression | None,
 ) -> AsyncGenerator[ParagraphClassificationBatch, None]:
     if len(trainset.filter.labels) != 1:
         raise HTTPException(

nucliadb/train/generators/paragraph_streaming.py
CHANGED

@@ -18,7 +18,7 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 
-from
+from collections.abc import AsyncGenerator
 
 from nidx_protos.nodereader_pb2 import StreamRequest
 
@@ -38,7 +38,7 @@ def paragraph_streaming_batch_generator(
     kbid: str,
     trainset: TrainSet,
     shard_replica_id: str,
-    filter_expression:
+    filter_expression: FilterExpression | None,
 ) -> AsyncGenerator[ParagraphStreamingBatch, None]:
     generator = generate_paragraph_streaming_payloads(kbid, trainset, shard_replica_id)
     batch_generator = batchify(generator, trainset.batch_size, ParagraphStreamingBatch)

nucliadb/train/generators/question_answer_streaming.py
CHANGED

@@ -18,7 +18,7 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 
-from
+from collections.abc import AsyncGenerator
 
 from nidx_protos.nodereader_pb2 import StreamRequest
 
@@ -47,7 +47,7 @@ def question_answer_batch_generator(
     kbid: str,
     trainset: TrainSet,
     shard_replica_id: str,
-    filter_expression:
+    filter_expression: FilterExpression | None,
 ) -> AsyncGenerator[QuestionAnswerStreamingBatch, None]:
     generator = generate_question_answer_streaming_payloads(kbid, trainset, shard_replica_id)
     batch_generator = batchify(generator, trainset.batch_size, QuestionAnswerStreamingBatch)

nucliadb/train/generators/sentence_classifier.py
CHANGED

@@ -18,7 +18,7 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 
-from
+from collections.abc import AsyncGenerator
 
 from fastapi import HTTPException
 from nidx_protos.nodereader_pb2 import StreamRequest
@@ -40,7 +40,7 @@ def sentence_classification_batch_generator(
     kbid: str,
     trainset: TrainSet,
     shard_replica_id: str,
-    filter_expression:
+    filter_expression: FilterExpression | None,
 ) -> AsyncGenerator[SentenceClassificationBatch, None]:
     if len(trainset.filter.labels) == 0:
         raise HTTPException(

nucliadb/train/generators/token_classifier.py
CHANGED

@@ -19,7 +19,8 @@
 #
 
 from collections import OrderedDict
-from
+from collections.abc import AsyncGenerator
+from typing import cast
 
 from nidx_protos.nodereader_pb2 import StreamFilter, StreamRequest
 
@@ -43,7 +44,7 @@ def token_classification_batch_generator(
     kbid: str,
     trainset: TrainSet,
    shard_replica_id: str,
-    filter_expression:
+    filter_expression: FilterExpression | None,
 ) -> AsyncGenerator[TokenClassificationBatch, None]:
     generator = generate_token_classification_payloads(kbid, trainset, shard_replica_id)
     batch_generator = batchify(generator, trainset.batch_size, TokenClassificationBatch)

nucliadb/train/generators/utils.py
CHANGED

@@ -18,7 +18,8 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.
 #
 
-from
+from collections.abc import AsyncGenerator, AsyncIterator
+from typing import Any
 
 from nucliadb.common.cache import get_resource_cache
 from nucliadb.common.ids import FIELD_TYPE_STR_TO_PB
@@ -30,16 +31,16 @@ from nucliadb.train.types import T
 from nucliadb_utils.utilities import get_storage
 
 
-async def get_resource_from_cache_or_db(kbid: str, uuid: str) ->
+async def get_resource_from_cache_or_db(kbid: str, uuid: str) -> ResourceORM | None:
     resource_cache = get_resource_cache()
     if resource_cache is None:
-        return await _get_resource_from_db(kbid, uuid)
         logger.warning("Resource cache is not set")
+        return await _get_resource_from_db(kbid, uuid)
 
     return await resource_cache.get(kbid, uuid)
 
 
-async def _get_resource_from_db(kbid: str, uuid: str) ->
+async def _get_resource_from_db(kbid: str, uuid: str) -> ResourceORM | None:
     storage = await get_storage(service_name=SERVICE_NAME)
     async with get_driver().ro_transaction() as transaction:
         kb = KnowledgeBoxORM(transaction, storage, kbid)
@@ -81,7 +82,7 @@ async def get_paragraph(kbid: str, paragraph_id: str) -> str:
 
 
 async def batchify(
-    producer: AsyncIterator[Any], size: int, batch_klass:
+    producer: AsyncIterator[Any], size: int, batch_klass: type[T]
 ) -> AsyncGenerator[T, None]:
     # NOTE: we are supposing all protobuffers have a data field
     batch = []