nucliadb-utils 6.9.1.post5229__py3-none-any.whl → 6.10.0.post5732__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nucliadb-utils might be problematic. Click here for more details.
- nucliadb_utils/asyncio_utils.py +3 -3
- nucliadb_utils/audit/audit.py +41 -31
- nucliadb_utils/audit/basic.py +22 -23
- nucliadb_utils/audit/stream.py +31 -31
- nucliadb_utils/authentication.py +8 -10
- nucliadb_utils/cache/nats.py +10 -12
- nucliadb_utils/cache/pubsub.py +5 -4
- nucliadb_utils/cache/settings.py +2 -3
- nucliadb_utils/const.py +1 -1
- nucliadb_utils/debug.py +2 -2
- nucliadb_utils/encryption/settings.py +1 -2
- nucliadb_utils/fastapi/openapi.py +1 -2
- nucliadb_utils/fastapi/versioning.py +10 -6
- nucliadb_utils/featureflagging.py +10 -4
- nucliadb_utils/grpc.py +3 -3
- nucliadb_utils/helpers.py +1 -1
- nucliadb_utils/nats.py +15 -16
- nucliadb_utils/nuclia_usage/utils/kb_usage_report.py +4 -5
- nucliadb_utils/run.py +1 -1
- nucliadb_utils/settings.py +40 -41
- nucliadb_utils/signals.py +3 -3
- nucliadb_utils/storages/azure.py +34 -21
- nucliadb_utils/storages/gcs.py +22 -21
- nucliadb_utils/storages/local.py +8 -8
- nucliadb_utils/storages/nuclia.py +1 -2
- nucliadb_utils/storages/object_store.py +6 -6
- nucliadb_utils/storages/s3.py +23 -23
- nucliadb_utils/storages/settings.py +7 -8
- nucliadb_utils/storages/storage.py +29 -45
- nucliadb_utils/storages/utils.py +2 -3
- nucliadb_utils/store.py +2 -2
- nucliadb_utils/tests/asyncbenchmark.py +8 -10
- nucliadb_utils/tests/azure.py +2 -1
- nucliadb_utils/tests/fixtures.py +3 -2
- nucliadb_utils/tests/gcs.py +3 -2
- nucliadb_utils/tests/local.py +2 -1
- nucliadb_utils/tests/nats.py +1 -1
- nucliadb_utils/tests/s3.py +2 -1
- nucliadb_utils/transaction.py +16 -18
- nucliadb_utils/utilities.py +22 -24
- {nucliadb_utils-6.9.1.post5229.dist-info → nucliadb_utils-6.10.0.post5732.dist-info}/METADATA +6 -6
- nucliadb_utils-6.10.0.post5732.dist-info/RECORD +59 -0
- nucliadb_utils-6.9.1.post5229.dist-info/RECORD +0 -59
- {nucliadb_utils-6.9.1.post5229.dist-info → nucliadb_utils-6.10.0.post5732.dist-info}/WHEEL +0 -0
- {nucliadb_utils-6.9.1.post5229.dist-info → nucliadb_utils-6.10.0.post5732.dist-info}/top_level.txt +0 -0
|
@@ -20,9 +20,11 @@
|
|
|
20
20
|
# This code is inspired by fastapi_versioning 1/3/2022 with MIT licence
|
|
21
21
|
|
|
22
22
|
from collections import defaultdict
|
|
23
|
-
from
|
|
23
|
+
from collections.abc import Callable
|
|
24
|
+
from typing import Any, Sequence, TypeVar, cast
|
|
24
25
|
|
|
25
26
|
from fastapi import FastAPI
|
|
27
|
+
from fastapi.middleware import Middleware
|
|
26
28
|
from fastapi.routing import APIRoute
|
|
27
29
|
from starlette.routing import BaseRoute
|
|
28
30
|
|
|
@@ -39,8 +41,8 @@ def version(major: int, minor: int = 0) -> Callable[[CallableT], CallableT]: #
|
|
|
39
41
|
|
|
40
42
|
def version_to_route(
|
|
41
43
|
route: BaseRoute,
|
|
42
|
-
default_version:
|
|
43
|
-
) ->
|
|
44
|
+
default_version: tuple[int, int],
|
|
45
|
+
) -> tuple[tuple[int, int], APIRoute]: # pragma: no cover
|
|
44
46
|
api_route = cast(APIRoute, route)
|
|
45
47
|
version = getattr(api_route.endpoint, "_api_version", default_version)
|
|
46
48
|
return version, api_route
|
|
@@ -50,17 +52,19 @@ def VersionedFastAPI(
|
|
|
50
52
|
app: FastAPI,
|
|
51
53
|
version_format: str = "{major}.{minor}",
|
|
52
54
|
prefix_format: str = "/v{major}_{minor}",
|
|
53
|
-
default_version:
|
|
55
|
+
default_version: tuple[int, int] = (1, 0),
|
|
54
56
|
enable_latest: bool = False,
|
|
55
|
-
|
|
57
|
+
middleware: Sequence[Middleware] | None = None,
|
|
58
|
+
kwargs: dict[str, object] | None = None,
|
|
56
59
|
) -> FastAPI: # pragma: no cover
|
|
57
60
|
kwargs = kwargs or {}
|
|
58
61
|
|
|
59
62
|
parent_app = FastAPI(
|
|
60
63
|
title=app.title,
|
|
64
|
+
middleware=middleware,
|
|
61
65
|
**kwargs, # type: ignore
|
|
62
66
|
)
|
|
63
|
-
version_route_mapping:
|
|
67
|
+
version_route_mapping: dict[tuple[int, int], list[APIRoute]] = defaultdict(list)
|
|
64
68
|
version_routes = [version_to_route(route, default_version) for route in app.routes]
|
|
65
69
|
|
|
66
70
|
for version, route in version_routes:
|
|
@@ -18,7 +18,7 @@
|
|
|
18
18
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
19
19
|
|
|
20
20
|
import json
|
|
21
|
-
from typing import Any
|
|
21
|
+
from typing import Any
|
|
22
22
|
|
|
23
23
|
import mrflagly
|
|
24
24
|
import pydantic_settings
|
|
@@ -28,7 +28,10 @@ from nucliadb_utils.settings import nuclia_settings, running_settings
|
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
class Settings(pydantic_settings.BaseSettings):
|
|
31
|
-
flag_settings_url:
|
|
31
|
+
flag_settings_url: str | None = None
|
|
32
|
+
|
|
33
|
+
# temporary flag to test this FF enabled/disabled easily
|
|
34
|
+
disable_ask_decoupled_ff: bool = False
|
|
32
35
|
|
|
33
36
|
|
|
34
37
|
DEFAULT_FLAG_DATA: dict[str, Any] = {
|
|
@@ -45,7 +48,10 @@ DEFAULT_FLAG_DATA: dict[str, Any] = {
|
|
|
45
48
|
"rollout": 0,
|
|
46
49
|
"variants": {"environment": ["local"]},
|
|
47
50
|
},
|
|
48
|
-
const.Features.
|
|
51
|
+
const.Features.ASK_DECOUPLED: {
|
|
52
|
+
"rollout": 0,
|
|
53
|
+
"variants": {"environment": [] if Settings().disable_ask_decoupled_ff else ["local"]},
|
|
54
|
+
},
|
|
49
55
|
}
|
|
50
56
|
|
|
51
57
|
|
|
@@ -57,7 +63,7 @@ class FlagService:
|
|
|
57
63
|
else:
|
|
58
64
|
self.flag_service = mrflagly.FlagService(url=settings.flag_settings_url)
|
|
59
65
|
|
|
60
|
-
def enabled(self, flag_key: str, default: bool = False, context:
|
|
66
|
+
def enabled(self, flag_key: str, default: bool = False, context: dict | None = None) -> bool:
|
|
61
67
|
if context is None:
|
|
62
68
|
context = {}
|
|
63
69
|
context["environment"] = running_settings.running_environment
|
nucliadb_utils/grpc.py
CHANGED
|
@@ -19,7 +19,6 @@
|
|
|
19
19
|
|
|
20
20
|
import json
|
|
21
21
|
import logging
|
|
22
|
-
from typing import Optional
|
|
23
22
|
|
|
24
23
|
from grpc import ChannelCredentials, aio
|
|
25
24
|
|
|
@@ -58,7 +57,7 @@ RETRY_OPTIONS = [
|
|
|
58
57
|
def get_traced_grpc_channel(
|
|
59
58
|
address: str,
|
|
60
59
|
service_name: str,
|
|
61
|
-
credentials:
|
|
60
|
+
credentials: ChannelCredentials | None = None,
|
|
62
61
|
variant: str = "",
|
|
63
62
|
max_send_message: int = 100,
|
|
64
63
|
) -> aio.Channel:
|
|
@@ -75,7 +74,8 @@ def get_traced_grpc_channel(
|
|
|
75
74
|
options = [
|
|
76
75
|
("grpc.max_receive_message_length", max_send_message * 1024 * 1024),
|
|
77
76
|
("grpc.max_send_message_length", max_send_message * 1024 * 1024),
|
|
78
|
-
|
|
77
|
+
*RETRY_OPTIONS,
|
|
78
|
+
]
|
|
79
79
|
channel = aio.insecure_channel(address, options=options)
|
|
80
80
|
return channel
|
|
81
81
|
|
nucliadb_utils/helpers.py
CHANGED
nucliadb_utils/nats.py
CHANGED
|
@@ -21,8 +21,9 @@ import asyncio
|
|
|
21
21
|
import logging
|
|
22
22
|
import sys
|
|
23
23
|
import time
|
|
24
|
+
from collections.abc import Awaitable, Callable
|
|
24
25
|
from functools import cached_property, partial
|
|
25
|
-
from typing import Any
|
|
26
|
+
from typing import Any
|
|
26
27
|
|
|
27
28
|
import nats
|
|
28
29
|
import nats.errors
|
|
@@ -67,7 +68,7 @@ class NatsMessageProgressUpdater(MessageProgressUpdater):
|
|
|
67
68
|
|
|
68
69
|
|
|
69
70
|
class NatsConnectionManager:
|
|
70
|
-
_nc:
|
|
71
|
+
_nc: NATSClient | NatsClientTelemetry
|
|
71
72
|
_subscriptions: list[tuple[Subscription, Callable[[], Awaitable[None]]]]
|
|
72
73
|
_pull_subscriptions: list[
|
|
73
74
|
tuple[
|
|
@@ -81,8 +82,8 @@ class NatsConnectionManager:
|
|
|
81
82
|
*,
|
|
82
83
|
service_name: str,
|
|
83
84
|
nats_servers: list[str],
|
|
84
|
-
nats_creds:
|
|
85
|
-
pull_utilization_metrics:
|
|
85
|
+
nats_creds: str | None = None,
|
|
86
|
+
pull_utilization_metrics: Counter | None = None,
|
|
86
87
|
):
|
|
87
88
|
self._service_name = service_name
|
|
88
89
|
self._nats_servers = nats_servers
|
|
@@ -91,9 +92,9 @@ class NatsConnectionManager:
|
|
|
91
92
|
self._pull_subscriptions = []
|
|
92
93
|
self._lock = asyncio.Lock()
|
|
93
94
|
self._healthy = True
|
|
94
|
-
self._last_unhealthy:
|
|
95
|
+
self._last_unhealthy: float | None = None
|
|
95
96
|
self._needs_reconnection = False
|
|
96
|
-
self._reconnect_task:
|
|
97
|
+
self._reconnect_task: asyncio.Task | None = None
|
|
97
98
|
self._expected_subscriptions: set[str] = set()
|
|
98
99
|
self._initialized = False
|
|
99
100
|
self.pull_utilization_metrics = pull_utilization_metrics
|
|
@@ -274,11 +275,11 @@ class NatsConnectionManager:
|
|
|
274
275
|
logger.info("Connection is closed on NATS")
|
|
275
276
|
|
|
276
277
|
@property
|
|
277
|
-
def nc(self) ->
|
|
278
|
+
def nc(self) -> NATSClient | NatsClientTelemetry:
|
|
278
279
|
return self._nc
|
|
279
280
|
|
|
280
281
|
@cached_property
|
|
281
|
-
def js(self) ->
|
|
282
|
+
def js(self) -> JetStreamContext | JetStreamContextTelemetry:
|
|
282
283
|
return get_traced_jetstream(self._nc, self._service_name)
|
|
283
284
|
|
|
284
285
|
async def subscribe(
|
|
@@ -291,7 +292,7 @@ class NatsConnectionManager:
|
|
|
291
292
|
subscription_lost_cb: Callable[[], Awaitable[None]],
|
|
292
293
|
flow_control: bool = False,
|
|
293
294
|
manual_ack: bool = True,
|
|
294
|
-
config:
|
|
295
|
+
config: nats.js.api.ConsumerConfig | None = None,
|
|
295
296
|
) -> Subscription:
|
|
296
297
|
sub = await self.js.subscribe(
|
|
297
298
|
subject=subject,
|
|
@@ -314,8 +315,8 @@ class NatsConnectionManager:
|
|
|
314
315
|
stream: str,
|
|
315
316
|
cb: Callable[[Msg], Awaitable[None]],
|
|
316
317
|
subscription_lost_cb: Callable[[], Awaitable[None]],
|
|
317
|
-
durable:
|
|
318
|
-
config:
|
|
318
|
+
durable: str | None = None,
|
|
319
|
+
config: nats.js.api.ConsumerConfig | None = None,
|
|
319
320
|
) -> JetStreamContext.PullSubscription:
|
|
320
321
|
wrapped_cb: Callable[[Msg], Awaitable[None]]
|
|
321
322
|
if isinstance(self.js, JetStreamContextTelemetry):
|
|
@@ -370,9 +371,7 @@ class NatsConnectionManager:
|
|
|
370
371
|
|
|
371
372
|
return psub
|
|
372
373
|
|
|
373
|
-
async def _remove_subscription(
|
|
374
|
-
self, subscription: Union[Subscription, JetStreamContext.PullSubscription]
|
|
375
|
-
):
|
|
374
|
+
async def _remove_subscription(self, subscription: Subscription | JetStreamContext.PullSubscription):
|
|
376
375
|
async with self._lock:
|
|
377
376
|
for index, (sub, _) in enumerate(self._subscriptions):
|
|
378
377
|
if sub is not subscription:
|
|
@@ -391,7 +390,7 @@ class NatsConnectionManager:
|
|
|
391
390
|
pass
|
|
392
391
|
return
|
|
393
392
|
|
|
394
|
-
async def unsubscribe(self, subscription:
|
|
393
|
+
async def unsubscribe(self, subscription: Subscription | JetStreamContext.PullSubscription):
|
|
395
394
|
await subscription.unsubscribe()
|
|
396
395
|
await self._remove_subscription(subscription)
|
|
397
396
|
|
|
@@ -403,7 +402,7 @@ class NatsConnectionManager:
|
|
|
403
402
|
while True:
|
|
404
403
|
await asyncio.sleep(30)
|
|
405
404
|
|
|
406
|
-
existing_subs =
|
|
405
|
+
existing_subs = {sub._consumer for sub, _, _, _ in self._pull_subscriptions}
|
|
407
406
|
missing_subs = self._expected_subscriptions - existing_subs
|
|
408
407
|
if missing_subs:
|
|
409
408
|
logger.warning(f"Some NATS subscriptions are missing {missing_subs}")
|
|
@@ -22,7 +22,6 @@ import logging
|
|
|
22
22
|
from collections.abc import Iterable
|
|
23
23
|
from contextlib import suppress
|
|
24
24
|
from datetime import datetime, timezone
|
|
25
|
-
from typing import Optional
|
|
26
25
|
|
|
27
26
|
from nats.js.client import JetStreamContext
|
|
28
27
|
|
|
@@ -90,14 +89,14 @@ class KbUsageReportUtility:
|
|
|
90
89
|
def send_kb_usage(
|
|
91
90
|
self,
|
|
92
91
|
service: Service,
|
|
93
|
-
account_id:
|
|
94
|
-
kb_id:
|
|
92
|
+
account_id: str | None,
|
|
93
|
+
kb_id: str | None,
|
|
95
94
|
kb_source: KBSource,
|
|
96
95
|
processes: Iterable[Process] = (),
|
|
97
96
|
predicts: Iterable[Predict] = (),
|
|
98
97
|
searches: Iterable[Search] = (),
|
|
99
|
-
storage:
|
|
100
|
-
activity_log_match:
|
|
98
|
+
storage: Storage | None = None,
|
|
99
|
+
activity_log_match: ActivityLogMatch | None = None,
|
|
101
100
|
):
|
|
102
101
|
usage = KbUsage()
|
|
103
102
|
usage.service = service # type: ignore
|
nucliadb_utils/run.py
CHANGED
nucliadb_utils/settings.py
CHANGED
|
@@ -18,7 +18,6 @@
|
|
|
18
18
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
19
19
|
|
|
20
20
|
from enum import Enum
|
|
21
|
-
from typing import Dict, List, Optional
|
|
22
21
|
|
|
23
22
|
from pydantic import AliasChoices, Field, model_validator
|
|
24
23
|
from pydantic_settings import BaseSettings
|
|
@@ -26,7 +25,7 @@ from pydantic_settings import BaseSettings
|
|
|
26
25
|
|
|
27
26
|
class RunningSettings(BaseSettings):
|
|
28
27
|
debug: bool = False
|
|
29
|
-
sentry_url:
|
|
28
|
+
sentry_url: str | None = None
|
|
30
29
|
running_environment: str = Field(
|
|
31
30
|
default="local",
|
|
32
31
|
validation_alias=AliasChoices("environment", "running_environment"),
|
|
@@ -42,7 +41,7 @@ running_settings = RunningSettings()
|
|
|
42
41
|
|
|
43
42
|
|
|
44
43
|
class HTTPSettings(BaseSettings):
|
|
45
|
-
cors_origins:
|
|
44
|
+
cors_origins: list[str] = ["*"]
|
|
46
45
|
|
|
47
46
|
|
|
48
47
|
http_settings = HTTPSettings()
|
|
@@ -70,47 +69,47 @@ class StorageSettings(BaseSettings):
|
|
|
70
69
|
default=FileBackendConfig.NOT_SET, description="File backend storage type"
|
|
71
70
|
)
|
|
72
71
|
|
|
73
|
-
gcs_base64_creds:
|
|
72
|
+
gcs_base64_creds: str | None = Field(
|
|
74
73
|
default=None,
|
|
75
|
-
description="GCS JSON credentials of a service account encoded in Base64: https://cloud.google.com/iam/docs/service-account-overview",
|
|
74
|
+
description="GCS JSON credentials of a service account encoded in Base64: https://cloud.google.com/iam/docs/service-account-overview",
|
|
76
75
|
)
|
|
77
|
-
gcs_bucket:
|
|
76
|
+
gcs_bucket: str | None = Field(
|
|
78
77
|
default=None,
|
|
79
78
|
description="GCS Bucket name where files are stored: https://cloud.google.com/storage/docs/buckets",
|
|
80
79
|
)
|
|
81
|
-
gcs_location:
|
|
80
|
+
gcs_location: str | None = Field(
|
|
82
81
|
default=None,
|
|
83
82
|
description="GCS Bucket location: https://cloud.google.com/storage/docs/locations",
|
|
84
83
|
)
|
|
85
|
-
gcs_project:
|
|
84
|
+
gcs_project: str | None = Field(
|
|
86
85
|
default=None,
|
|
87
|
-
description="Google Cloud Project ID: https://cloud.google.com/resource-manager/docs/creating-managing-projects",
|
|
86
|
+
description="Google Cloud Project ID: https://cloud.google.com/resource-manager/docs/creating-managing-projects",
|
|
88
87
|
)
|
|
89
|
-
gcs_bucket_labels:
|
|
88
|
+
gcs_bucket_labels: dict[str, str] = Field(
|
|
90
89
|
default={},
|
|
91
|
-
description="Map of labels with which GCS buckets will be labeled with: https://cloud.google.com/storage/docs/tags-and-labels",
|
|
90
|
+
description="Map of labels with which GCS buckets will be labeled with: https://cloud.google.com/storage/docs/tags-and-labels",
|
|
92
91
|
)
|
|
93
92
|
gcs_endpoint_url: str = "https://www.googleapis.com"
|
|
94
93
|
|
|
95
|
-
s3_client_id:
|
|
96
|
-
s3_client_secret:
|
|
94
|
+
s3_client_id: str | None = None
|
|
95
|
+
s3_client_secret: str | None = None
|
|
97
96
|
s3_ssl: bool = True
|
|
98
97
|
s3_verify_ssl: bool = True
|
|
99
98
|
s3_max_pool_connections: int = 30
|
|
100
|
-
s3_endpoint:
|
|
101
|
-
s3_region_name:
|
|
102
|
-
s3_kms_key_id:
|
|
103
|
-
s3_bucket:
|
|
104
|
-
s3_bucket_tags:
|
|
99
|
+
s3_endpoint: str | None = None
|
|
100
|
+
s3_region_name: str | None = None
|
|
101
|
+
s3_kms_key_id: str | None = None
|
|
102
|
+
s3_bucket: str | None = Field(default=None, description="KnowledgeBox S3 bucket name template")
|
|
103
|
+
s3_bucket_tags: dict[str, str] = Field(
|
|
105
104
|
default={},
|
|
106
|
-
description="Map of tags with which S3 buckets will be tagged with: https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutBucketTagging.html",
|
|
105
|
+
description="Map of tags with which S3 buckets will be tagged with: https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutBucketTagging.html",
|
|
107
106
|
)
|
|
108
107
|
|
|
109
|
-
local_files:
|
|
108
|
+
local_files: str | None = Field(
|
|
110
109
|
default=None,
|
|
111
110
|
description="If using LOCAL `file_backend` storage, directory where files should be stored",
|
|
112
111
|
)
|
|
113
|
-
local_indexing_bucket:
|
|
112
|
+
local_indexing_bucket: str | None = Field(
|
|
114
113
|
default="indexer",
|
|
115
114
|
description="If using LOCAL `file_backend` storage, subdirectory where indexing data is stored",
|
|
116
115
|
)
|
|
@@ -119,28 +118,28 @@ class StorageSettings(BaseSettings):
|
|
|
119
118
|
description="Number of days that uploaded files are kept in Nulia's processing engine",
|
|
120
119
|
)
|
|
121
120
|
|
|
122
|
-
azure_account_url:
|
|
121
|
+
azure_account_url: str | None = Field(
|
|
123
122
|
default=None,
|
|
124
|
-
description="Azure Account URL. The driver implementation uses Azure's default credential authentication method: https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.defaultazurecredential?view=azure-python",
|
|
123
|
+
description="Azure Account URL. The driver implementation uses Azure's default credential authentication method: https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.defaultazurecredential?view=azure-python",
|
|
125
124
|
examples=["https://<storageaccountname>.blob.core.windows.net"],
|
|
126
125
|
)
|
|
127
126
|
|
|
128
|
-
azure_kb_account_url:
|
|
127
|
+
azure_kb_account_url: str | None = Field(
|
|
129
128
|
default=None,
|
|
130
|
-
description="Azure Account URL for KB containers. If unspecified, uses `azure_account_url`",
|
|
129
|
+
description="Azure Account URL for KB containers. If unspecified, uses `azure_account_url`",
|
|
131
130
|
examples=["https://<storageaccountname>.blob.core.windows.net"],
|
|
132
131
|
)
|
|
133
132
|
|
|
134
133
|
# For testing purposes: Azurite docker image requires a connection string as it
|
|
135
134
|
# doesn't support Azure's default credential authentication method
|
|
136
|
-
azure_connection_string:
|
|
135
|
+
azure_connection_string: str | None = None
|
|
137
136
|
|
|
138
137
|
|
|
139
138
|
storage_settings = StorageSettings()
|
|
140
139
|
|
|
141
140
|
|
|
142
141
|
class NucliaSettings(BaseSettings):
|
|
143
|
-
nuclia_service_account:
|
|
142
|
+
nuclia_service_account: str | None = None
|
|
144
143
|
nuclia_public_url: str = "https://{zone}.nuclia.cloud"
|
|
145
144
|
nuclia_processing_cluster_url: str = "http://processing-api.processing.svc.cluster.local:8080"
|
|
146
145
|
nuclia_inner_predict_url: str = "http://predict.learning.svc.cluster.local:8080"
|
|
@@ -149,7 +148,7 @@ class NucliaSettings(BaseSettings):
|
|
|
149
148
|
nuclia_zone: str = "europe-1"
|
|
150
149
|
onprem: bool = True
|
|
151
150
|
|
|
152
|
-
nuclia_jwt_key:
|
|
151
|
+
nuclia_jwt_key: str | None = None
|
|
153
152
|
nuclia_hash_seed: int = 42
|
|
154
153
|
nuclia_partitions: int = 1
|
|
155
154
|
|
|
@@ -157,7 +156,7 @@ class NucliaSettings(BaseSettings):
|
|
|
157
156
|
dummy_predict: bool = False
|
|
158
157
|
dummy_learning_services: bool = False
|
|
159
158
|
local_predict: bool = False
|
|
160
|
-
local_predict_headers:
|
|
159
|
+
local_predict_headers: dict[str, str] = {}
|
|
161
160
|
|
|
162
161
|
@model_validator(mode="before")
|
|
163
162
|
@classmethod
|
|
@@ -171,15 +170,15 @@ nuclia_settings = NucliaSettings()
|
|
|
171
170
|
|
|
172
171
|
|
|
173
172
|
class NucliaDBSettings(BaseSettings):
|
|
174
|
-
nucliadb_ingest:
|
|
173
|
+
nucliadb_ingest: str | None = "ingest-orm-grpc.nucliadb.svc.cluster.local:8030"
|
|
175
174
|
|
|
176
175
|
|
|
177
176
|
nucliadb_settings = NucliaDBSettings()
|
|
178
177
|
|
|
179
178
|
|
|
180
179
|
class TransactionSettings(BaseSettings):
|
|
181
|
-
transaction_jetstream_auth:
|
|
182
|
-
transaction_jetstream_servers:
|
|
180
|
+
transaction_jetstream_auth: str | None = None
|
|
181
|
+
transaction_jetstream_servers: list[str] = ["nats://localhost:4222"]
|
|
183
182
|
transaction_local: bool = False
|
|
184
183
|
transaction_commit_timeout: int = Field(
|
|
185
184
|
default=60, description="Transaction commit timeout in seconds"
|
|
@@ -190,10 +189,10 @@ transaction_settings = TransactionSettings()
|
|
|
190
189
|
|
|
191
190
|
|
|
192
191
|
class IndexingSettings(BaseSettings):
|
|
193
|
-
index_jetstream_servers:
|
|
194
|
-
index_jetstream_auth:
|
|
192
|
+
index_jetstream_servers: list[str] = []
|
|
193
|
+
index_jetstream_auth: str | None = None
|
|
195
194
|
index_local: bool = False
|
|
196
|
-
index_nidx_subject:
|
|
195
|
+
index_nidx_subject: str | None = None
|
|
197
196
|
index_searcher_refresh_interval: float = 1.0
|
|
198
197
|
|
|
199
198
|
|
|
@@ -202,9 +201,9 @@ indexing_settings = IndexingSettings()
|
|
|
202
201
|
|
|
203
202
|
class AuditSettings(BaseSettings):
|
|
204
203
|
audit_driver: str = "basic"
|
|
205
|
-
audit_jetstream_target:
|
|
206
|
-
audit_jetstream_servers:
|
|
207
|
-
audit_jetstream_auth:
|
|
204
|
+
audit_jetstream_target: str | None = "audit.{partition}.{type}"
|
|
205
|
+
audit_jetstream_servers: list[str] = []
|
|
206
|
+
audit_jetstream_auth: str | None = None
|
|
208
207
|
audit_partitions: int = 3
|
|
209
208
|
audit_stream: str = "audit"
|
|
210
209
|
audit_hash_seed: int = 1234
|
|
@@ -214,9 +213,9 @@ audit_settings = AuditSettings()
|
|
|
214
213
|
|
|
215
214
|
|
|
216
215
|
class UsageSettings(BaseSettings):
|
|
217
|
-
usage_jetstream_subject:
|
|
218
|
-
usage_jetstream_servers:
|
|
219
|
-
usage_jetstream_auth:
|
|
216
|
+
usage_jetstream_subject: str | None = "kb-usage.nuclia_db"
|
|
217
|
+
usage_jetstream_servers: list[str] = []
|
|
218
|
+
usage_jetstream_auth: str | None = None
|
|
220
219
|
usage_stream: str = "kb-usage"
|
|
221
220
|
|
|
222
221
|
|
nucliadb_utils/signals.py
CHANGED
|
@@ -18,10 +18,10 @@
|
|
|
18
18
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
19
19
|
|
|
20
20
|
import asyncio
|
|
21
|
-
from collections.abc import Awaitable
|
|
21
|
+
from collections.abc import Awaitable, Callable
|
|
22
22
|
from enum import Enum
|
|
23
23
|
from inspect import iscoroutinefunction
|
|
24
|
-
from typing import Any
|
|
24
|
+
from typing import Any
|
|
25
25
|
|
|
26
26
|
from nucliadb_telemetry.errors import capture_exception
|
|
27
27
|
from nucliadb_utils import logger
|
|
@@ -33,7 +33,7 @@ class ListenerPriority(Enum):
|
|
|
33
33
|
|
|
34
34
|
|
|
35
35
|
class Signal:
|
|
36
|
-
def __init__(self, payload_model:
|
|
36
|
+
def __init__(self, payload_model: type):
|
|
37
37
|
self.payload_model_type = payload_model
|
|
38
38
|
self.callbacks: dict[str, tuple[Callable[..., Awaitable], int]] = {}
|
|
39
39
|
|
nucliadb_utils/storages/azure.py
CHANGED
|
@@ -20,9 +20,10 @@
|
|
|
20
20
|
|
|
21
21
|
from __future__ import annotations
|
|
22
22
|
|
|
23
|
+
import base64
|
|
23
24
|
import logging
|
|
25
|
+
from collections.abc import AsyncGenerator, AsyncIterator
|
|
24
26
|
from datetime import datetime
|
|
25
|
-
from typing import AsyncGenerator, AsyncIterator, Optional, Union
|
|
26
27
|
|
|
27
28
|
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
|
|
28
29
|
from azure.identity import DefaultAzureCredential
|
|
@@ -65,7 +66,7 @@ class AzureStorageField(StorageField):
|
|
|
65
66
|
origin_bucket_name, origin_uri, destination_bucket_name, destination_uri
|
|
66
67
|
)
|
|
67
68
|
|
|
68
|
-
async def iter_data(self, range:
|
|
69
|
+
async def iter_data(self, range: Range | None = None) -> AsyncGenerator[bytes]:
|
|
69
70
|
if self.field is not None:
|
|
70
71
|
bucket = self.field.bucket_name
|
|
71
72
|
key = self.field.uri
|
|
@@ -132,7 +133,7 @@ class AzureStorageField(StorageField):
|
|
|
132
133
|
self.field.ClearField("upload_uri")
|
|
133
134
|
self.field.ClearField("parts")
|
|
134
135
|
|
|
135
|
-
async def exists(self) ->
|
|
136
|
+
async def exists(self) -> ObjectMetadata | None:
|
|
136
137
|
key = None
|
|
137
138
|
bucket = None
|
|
138
139
|
if self.field is not None and self.field.uri != "":
|
|
@@ -170,9 +171,9 @@ class AzureStorage(Storage):
|
|
|
170
171
|
self,
|
|
171
172
|
account_url: str,
|
|
172
173
|
kb_account_url: str,
|
|
173
|
-
deadletter_bucket:
|
|
174
|
-
indexing_bucket:
|
|
175
|
-
connection_string:
|
|
174
|
+
deadletter_bucket: str | None = "deadletter",
|
|
175
|
+
indexing_bucket: str | None = "indexing",
|
|
176
|
+
connection_string: str | None = None,
|
|
176
177
|
):
|
|
177
178
|
self.object_store = AzureObjectStore(account_url, connection_string=connection_string)
|
|
178
179
|
self.kb_object_store = AzureObjectStore(kb_account_url, connection_string=connection_string)
|
|
@@ -185,7 +186,7 @@ class AzureStorage(Storage):
|
|
|
185
186
|
else:
|
|
186
187
|
return self.kb_object_store
|
|
187
188
|
|
|
188
|
-
async def initialize(self, service_name:
|
|
189
|
+
async def initialize(self, service_name: str | None = None):
|
|
189
190
|
await self.object_store.initialize()
|
|
190
191
|
await self.kb_object_store.initialize()
|
|
191
192
|
for bucket in [
|
|
@@ -208,7 +209,7 @@ class AzureStorage(Storage):
|
|
|
208
209
|
except KeyError:
|
|
209
210
|
pass
|
|
210
211
|
|
|
211
|
-
async def create_bucket(self, bucket_name: str, kbid:
|
|
212
|
+
async def create_bucket(self, bucket_name: str, kbid: str | None = None):
|
|
212
213
|
if await self.object_store_for_bucket(bucket_name).bucket_exists(bucket_name):
|
|
213
214
|
return
|
|
214
215
|
await self.object_store_for_bucket(bucket_name).bucket_create(bucket_name)
|
|
@@ -230,8 +231,8 @@ class AzureStorage(Storage):
|
|
|
230
231
|
return await self.kb_object_store.bucket_delete(bucket_name)
|
|
231
232
|
|
|
232
233
|
async def iterate_objects(
|
|
233
|
-
self, bucket: str, prefix: str, start:
|
|
234
|
-
) -> AsyncGenerator[ObjectInfo
|
|
234
|
+
self, bucket: str, prefix: str, start: str | None = None
|
|
235
|
+
) -> AsyncGenerator[ObjectInfo]:
|
|
235
236
|
async for obj in self.object_store_for_bucket(bucket).iterate(bucket, prefix, start):
|
|
236
237
|
yield obj
|
|
237
238
|
|
|
@@ -240,10 +241,10 @@ class AzureStorage(Storage):
|
|
|
240
241
|
|
|
241
242
|
|
|
242
243
|
class AzureObjectStore(ObjectStore):
|
|
243
|
-
def __init__(self, account_url: str, connection_string:
|
|
244
|
+
def __init__(self, account_url: str, connection_string: str | None = None):
|
|
244
245
|
self.account_url = account_url
|
|
245
246
|
self.connection_string = connection_string
|
|
246
|
-
self._service_client:
|
|
247
|
+
self._service_client: BlobServiceClient | None = None
|
|
247
248
|
|
|
248
249
|
@property
|
|
249
250
|
def service_client(self) -> BlobServiceClient:
|
|
@@ -344,11 +345,11 @@ class AzureObjectStore(ObjectStore):
|
|
|
344
345
|
self,
|
|
345
346
|
bucket: str,
|
|
346
347
|
key: str,
|
|
347
|
-
data:
|
|
348
|
+
data: bytes | AsyncGenerator[bytes],
|
|
348
349
|
metadata: ObjectMetadata,
|
|
349
350
|
) -> None:
|
|
350
351
|
container_client = self.service_client.get_container_client(bucket)
|
|
351
|
-
length:
|
|
352
|
+
length: int | None = None
|
|
352
353
|
if isinstance(data, bytes):
|
|
353
354
|
length = len(data)
|
|
354
355
|
metadata.size = length
|
|
@@ -384,8 +385,8 @@ class AzureObjectStore(ObjectStore):
|
|
|
384
385
|
return await downloader.readall()
|
|
385
386
|
|
|
386
387
|
async def download_stream(
|
|
387
|
-
self, bucket: str, key: str, range:
|
|
388
|
-
) -> AsyncGenerator[bytes
|
|
388
|
+
self, bucket: str, key: str, range: Range | None = None
|
|
389
|
+
) -> AsyncGenerator[bytes]:
|
|
389
390
|
range = range or Range()
|
|
390
391
|
container_client = self.service_client.get_container_client(bucket)
|
|
391
392
|
blob_client = container_client.get_blob_client(key)
|
|
@@ -405,8 +406,8 @@ class AzureObjectStore(ObjectStore):
|
|
|
405
406
|
yield chunk
|
|
406
407
|
|
|
407
408
|
async def iterate(
|
|
408
|
-
self, bucket: str, prefix: str, start:
|
|
409
|
-
) -> AsyncGenerator[ObjectInfo
|
|
409
|
+
self, bucket: str, prefix: str, start: str | None = None
|
|
410
|
+
) -> AsyncGenerator[ObjectInfo]:
|
|
410
411
|
container_client = self.service_client.get_container_client(bucket)
|
|
411
412
|
async for blob in container_client.list_blobs(name_starts_with=prefix):
|
|
412
413
|
if start and blob.name <= start:
|
|
@@ -426,13 +427,20 @@ class AzureObjectStore(ObjectStore):
|
|
|
426
427
|
@ops_observer.wrap({"type": "multipart_start"})
|
|
427
428
|
async def upload_multipart_start(self, bucket: str, key: str, metadata: ObjectMetadata) -> None:
|
|
428
429
|
container_client = self.service_client.get_container_client(bucket)
|
|
429
|
-
custom_metadata = {
|
|
430
|
+
custom_metadata = {
|
|
431
|
+
"base64_filename": base64.b64encode(metadata.filename.encode()).decode(),
|
|
432
|
+
"content_type": metadata.content_type,
|
|
433
|
+
"size": str(metadata.size),
|
|
434
|
+
}
|
|
430
435
|
blob_client = container_client.get_blob_client(key)
|
|
436
|
+
safe_filename = (
|
|
437
|
+
metadata.filename.encode("ascii", "replace").decode().replace('"', "").replace("\n", "")
|
|
438
|
+
)
|
|
431
439
|
await blob_client.create_append_blob(
|
|
432
440
|
metadata=custom_metadata,
|
|
433
441
|
content_settings=ContentSettings(
|
|
434
442
|
content_type=metadata.content_type,
|
|
435
|
-
content_disposition=f
|
|
443
|
+
content_disposition=f'attachment; filename="{safe_filename}"',
|
|
436
444
|
),
|
|
437
445
|
)
|
|
438
446
|
|
|
@@ -460,7 +468,12 @@ def parse_object_metadata(properties: BlobProperties, key: str) -> ObjectMetadat
|
|
|
460
468
|
size = int(custom_metadata_size)
|
|
461
469
|
else:
|
|
462
470
|
size = properties.size
|
|
463
|
-
|
|
471
|
+
|
|
472
|
+
b64_filename = custom_metadata.get("base64_filename")
|
|
473
|
+
if b64_filename:
|
|
474
|
+
filename = base64.b64decode(b64_filename.encode()).decode()
|
|
475
|
+
else:
|
|
476
|
+
filename = key.split("/")[-1]
|
|
464
477
|
content_type = custom_metadata.get("content_type") or properties.content_settings.content_type or ""
|
|
465
478
|
return ObjectMetadata(
|
|
466
479
|
filename=filename,
|