nucliadb-utils 6.9.1.post5229__py3-none-any.whl → 6.10.0.post5732__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of nucliadb-utils might be problematic.

Files changed (45)
  1. nucliadb_utils/asyncio_utils.py +3 -3
  2. nucliadb_utils/audit/audit.py +41 -31
  3. nucliadb_utils/audit/basic.py +22 -23
  4. nucliadb_utils/audit/stream.py +31 -31
  5. nucliadb_utils/authentication.py +8 -10
  6. nucliadb_utils/cache/nats.py +10 -12
  7. nucliadb_utils/cache/pubsub.py +5 -4
  8. nucliadb_utils/cache/settings.py +2 -3
  9. nucliadb_utils/const.py +1 -1
  10. nucliadb_utils/debug.py +2 -2
  11. nucliadb_utils/encryption/settings.py +1 -2
  12. nucliadb_utils/fastapi/openapi.py +1 -2
  13. nucliadb_utils/fastapi/versioning.py +10 -6
  14. nucliadb_utils/featureflagging.py +10 -4
  15. nucliadb_utils/grpc.py +3 -3
  16. nucliadb_utils/helpers.py +1 -1
  17. nucliadb_utils/nats.py +15 -16
  18. nucliadb_utils/nuclia_usage/utils/kb_usage_report.py +4 -5
  19. nucliadb_utils/run.py +1 -1
  20. nucliadb_utils/settings.py +40 -41
  21. nucliadb_utils/signals.py +3 -3
  22. nucliadb_utils/storages/azure.py +34 -21
  23. nucliadb_utils/storages/gcs.py +22 -21
  24. nucliadb_utils/storages/local.py +8 -8
  25. nucliadb_utils/storages/nuclia.py +1 -2
  26. nucliadb_utils/storages/object_store.py +6 -6
  27. nucliadb_utils/storages/s3.py +23 -23
  28. nucliadb_utils/storages/settings.py +7 -8
  29. nucliadb_utils/storages/storage.py +29 -45
  30. nucliadb_utils/storages/utils.py +2 -3
  31. nucliadb_utils/store.py +2 -2
  32. nucliadb_utils/tests/asyncbenchmark.py +8 -10
  33. nucliadb_utils/tests/azure.py +2 -1
  34. nucliadb_utils/tests/fixtures.py +3 -2
  35. nucliadb_utils/tests/gcs.py +3 -2
  36. nucliadb_utils/tests/local.py +2 -1
  37. nucliadb_utils/tests/nats.py +1 -1
  38. nucliadb_utils/tests/s3.py +2 -1
  39. nucliadb_utils/transaction.py +16 -18
  40. nucliadb_utils/utilities.py +22 -24
  41. {nucliadb_utils-6.9.1.post5229.dist-info → nucliadb_utils-6.10.0.post5732.dist-info}/METADATA +6 -6
  42. nucliadb_utils-6.10.0.post5732.dist-info/RECORD +59 -0
  43. nucliadb_utils-6.9.1.post5229.dist-info/RECORD +0 -59
  44. {nucliadb_utils-6.9.1.post5229.dist-info → nucliadb_utils-6.10.0.post5732.dist-info}/WHEEL +0 -0
  45. {nucliadb_utils-6.9.1.post5229.dist-info → nucliadb_utils-6.10.0.post5732.dist-info}/top_level.txt +0 -0
nucliadb_utils/fastapi/versioning.py CHANGED
@@ -20,9 +20,11 @@
 # This code is inspired by fastapi_versioning 1/3/2022 with MIT licence

 from collections import defaultdict
-from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, cast
+from collections.abc import Callable
+from typing import Any, Sequence, TypeVar, cast

 from fastapi import FastAPI
+from fastapi.middleware import Middleware
 from fastapi.routing import APIRoute
 from starlette.routing import BaseRoute

@@ -39,8 +41,8 @@ def version(major: int, minor: int = 0) -> Callable[[CallableT], CallableT]: #

 def version_to_route(
     route: BaseRoute,
-    default_version: Tuple[int, int],
-) -> Tuple[Tuple[int, int], APIRoute]:  # pragma: no cover
+    default_version: tuple[int, int],
+) -> tuple[tuple[int, int], APIRoute]:  # pragma: no cover
     api_route = cast(APIRoute, route)
     version = getattr(api_route.endpoint, "_api_version", default_version)
     return version, api_route
@@ -50,17 +52,19 @@ def VersionedFastAPI(
     app: FastAPI,
     version_format: str = "{major}.{minor}",
     prefix_format: str = "/v{major}_{minor}",
-    default_version: Tuple[int, int] = (1, 0),
+    default_version: tuple[int, int] = (1, 0),
     enable_latest: bool = False,
-    kwargs: Optional[Dict[str, object]] = None,
+    middleware: Sequence[Middleware] | None = None,
+    kwargs: dict[str, object] | None = None,
 ) -> FastAPI:  # pragma: no cover
     kwargs = kwargs or {}

     parent_app = FastAPI(
         title=app.title,
+        middleware=middleware,
         **kwargs,  # type: ignore
     )
-    version_route_mapping: Dict[Tuple[int, int], List[APIRoute]] = defaultdict(list)
+    version_route_mapping: dict[tuple[int, int], list[APIRoute]] = defaultdict(list)
     version_routes = [version_to_route(route, default_version) for route in app.routes]

     for version, route in version_routes:
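The new `middleware` argument means callers no longer have to tunnel middleware through `kwargs`; it is applied once to the parent app that mounts every versioned sub-application. A minimal usage sketch, assuming the `version` decorator defined in this same module (the CORS middleware is just an illustrative choice):

from fastapi import FastAPI
from fastapi.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware

from nucliadb_utils.fastapi.versioning import VersionedFastAPI, version

app = FastAPI()

@app.get("/hello")
@version(1)
async def hello():
    return {"msg": "hello"}

# Middleware now goes on the parent app explicitly instead of via `kwargs`.
versioned_app = VersionedFastAPI(
    app,
    middleware=[Middleware(CORSMiddleware, allow_origins=["*"])],
)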
nucliadb_utils/featureflagging.py CHANGED
@@ -18,7 +18,7 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.

 import json
-from typing import Any, Optional
+from typing import Any

 import mrflagly
 import pydantic_settings
@@ -28,7 +28,10 @@ from nucliadb_utils.settings import nuclia_settings, running_settings


 class Settings(pydantic_settings.BaseSettings):
-    flag_settings_url: Optional[str] = None
+    flag_settings_url: str | None = None
+
+    # temporary flag to test this FF enabled/disabled easily
+    disable_ask_decoupled_ff: bool = False


 DEFAULT_FLAG_DATA: dict[str, Any] = {
@@ -45,7 +48,10 @@ DEFAULT_FLAG_DATA: dict[str, Any] = {
         "rollout": 0,
         "variants": {"environment": ["local"]},
     },
-    const.Features.REBALANCE_ENABLED: {"rollout": 0, "variants": {"environment": ["local"]}},
+    const.Features.ASK_DECOUPLED: {
+        "rollout": 0,
+        "variants": {"environment": [] if Settings().disable_ask_decoupled_ff else ["local"]},
+    },
 }


@@ -57,7 +63,7 @@ class FlagService:
         else:
             self.flag_service = mrflagly.FlagService(url=settings.flag_settings_url)

-    def enabled(self, flag_key: str, default: bool = False, context: Optional[dict] = None) -> bool:
+    def enabled(self, flag_key: str, default: bool = False, context: dict | None = None) -> bool:
         if context is None:
             context = {}
         context["environment"] = running_settings.running_environment
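Since `Settings` is a pydantic-settings `BaseSettings`, the temporary `disable_ask_decoupled_ff` kill switch can be flipped from the environment, which empties the `environment` variant list for ASK_DECOUPLED. A small sketch, assuming pydantic-settings' default env-name mapping (the field name, case-insensitive):

import os

# Must be set before the module is imported, because DEFAULT_FLAG_DATA
# evaluates Settings() when nucliadb_utils.featureflagging is loaded.
os.environ["DISABLE_ASK_DECOUPLED_FF"] = "true"

from nucliadb_utils.featureflagging import Settings

assert Settings().disable_ask_decoupled_ff is True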
nucliadb_utils/grpc.py CHANGED
@@ -19,7 +19,6 @@

 import json
 import logging
-from typing import Optional

 from grpc import ChannelCredentials, aio

@@ -58,7 +57,7 @@ RETRY_OPTIONS = [
 def get_traced_grpc_channel(
     address: str,
     service_name: str,
-    credentials: Optional[ChannelCredentials] = None,
+    credentials: ChannelCredentials | None = None,
     variant: str = "",
     max_send_message: int = 100,
 ) -> aio.Channel:
@@ -75,7 +74,8 @@
     options = [
         ("grpc.max_receive_message_length", max_send_message * 1024 * 1024),
         ("grpc.max_send_message_length", max_send_message * 1024 * 1024),
-    ] + RETRY_OPTIONS
+        *RETRY_OPTIONS,
+    ]
     channel = aio.insecure_channel(address, options=options)
     return channel

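The options change is purely stylistic: unpacking `RETRY_OPTIONS` inside the list literal yields the same list as concatenating, without an intermediate list. A quick equivalence check, with a stand-in value for `RETRY_OPTIONS` (the real value lives in grpc.py):

RETRY_OPTIONS = [("grpc.enable_retries", 1)]  # stand-in for illustration

max_send_message = 100
size_options = [
    ("grpc.max_receive_message_length", max_send_message * 1024 * 1024),
    ("grpc.max_send_message_length", max_send_message * 1024 * 1024),
]

# Both spellings build the identical options list.
assert size_options + RETRY_OPTIONS == [*size_options, *RETRY_OPTIONS]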
nucliadb_utils/helpers.py CHANGED
@@ -19,7 +19,7 @@
 #
 import asyncio
 import logging
-from typing import AsyncGenerator, Awaitable, Callable
+from collections.abc import AsyncGenerator, Awaitable, Callable

 from nucliadb_telemetry.errors import capture_exception

nucliadb_utils/nats.py CHANGED
@@ -21,8 +21,9 @@ import asyncio
 import logging
 import sys
 import time
+from collections.abc import Awaitable, Callable
 from functools import cached_property, partial
-from typing import Any, Awaitable, Callable, Optional, Union
+from typing import Any

 import nats
 import nats.errors
@@ -67,7 +68,7 @@ class NatsMessageProgressUpdater(MessageProgressUpdater):


 class NatsConnectionManager:
-    _nc: Union[NATSClient, NatsClientTelemetry]
+    _nc: NATSClient | NatsClientTelemetry
     _subscriptions: list[tuple[Subscription, Callable[[], Awaitable[None]]]]
     _pull_subscriptions: list[
         tuple[
@@ -81,8 +82,8 @@ class NatsConnectionManager:
         *,
         service_name: str,
         nats_servers: list[str],
-        nats_creds: Optional[str] = None,
-        pull_utilization_metrics: Optional[Counter] = None,
+        nats_creds: str | None = None,
+        pull_utilization_metrics: Counter | None = None,
     ):
         self._service_name = service_name
         self._nats_servers = nats_servers
@@ -91,9 +92,9 @@ class NatsConnectionManager:
         self._pull_subscriptions = []
         self._lock = asyncio.Lock()
         self._healthy = True
-        self._last_unhealthy: Optional[float] = None
+        self._last_unhealthy: float | None = None
         self._needs_reconnection = False
-        self._reconnect_task: Optional[asyncio.Task] = None
+        self._reconnect_task: asyncio.Task | None = None
         self._expected_subscriptions: set[str] = set()
         self._initialized = False
         self.pull_utilization_metrics = pull_utilization_metrics
@@ -274,11 +275,11 @@
         logger.info("Connection is closed on NATS")

     @property
-    def nc(self) -> Union[NATSClient, NatsClientTelemetry]:
+    def nc(self) -> NATSClient | NatsClientTelemetry:
         return self._nc

     @cached_property
-    def js(self) -> Union[JetStreamContext, JetStreamContextTelemetry]:
+    def js(self) -> JetStreamContext | JetStreamContextTelemetry:
         return get_traced_jetstream(self._nc, self._service_name)

     async def subscribe(
@@ -291,7 +292,7 @@
         subscription_lost_cb: Callable[[], Awaitable[None]],
         flow_control: bool = False,
         manual_ack: bool = True,
-        config: Optional[nats.js.api.ConsumerConfig] = None,
+        config: nats.js.api.ConsumerConfig | None = None,
     ) -> Subscription:
         sub = await self.js.subscribe(
             subject=subject,
@@ -314,8 +315,8 @@
         stream: str,
         cb: Callable[[Msg], Awaitable[None]],
         subscription_lost_cb: Callable[[], Awaitable[None]],
-        durable: Optional[str] = None,
-        config: Optional[nats.js.api.ConsumerConfig] = None,
+        durable: str | None = None,
+        config: nats.js.api.ConsumerConfig | None = None,
     ) -> JetStreamContext.PullSubscription:
         wrapped_cb: Callable[[Msg], Awaitable[None]]
         if isinstance(self.js, JetStreamContextTelemetry):
@@ -370,9 +371,7 @@

         return psub

-    async def _remove_subscription(
-        self, subscription: Union[Subscription, JetStreamContext.PullSubscription]
-    ):
+    async def _remove_subscription(self, subscription: Subscription | JetStreamContext.PullSubscription):
         async with self._lock:
             for index, (sub, _) in enumerate(self._subscriptions):
                 if sub is not subscription:
@@ -391,7 +390,7 @@
                 pass
             return

-    async def unsubscribe(self, subscription: Union[Subscription, JetStreamContext.PullSubscription]):
+    async def unsubscribe(self, subscription: Subscription | JetStreamContext.PullSubscription):
         await subscription.unsubscribe()
         await self._remove_subscription(subscription)

@@ -403,7 +402,7 @@
         while True:
             await asyncio.sleep(30)

-            existing_subs = set(sub._consumer for sub, _, _, _ in self._pull_subscriptions)
+            existing_subs = {sub._consumer for sub, _, _, _ in self._pull_subscriptions}
             missing_subs = self._expected_subscriptions - existing_subs
             if missing_subs:
                 logger.warning(f"Some NATS subscriptions are missing {missing_subs}")
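The set comprehension here is behavior-identical to the old `set(...)` over a generator; the interesting part is the set difference that flags consumers that should exist but have no live pull subscription. A toy illustration with hypothetical consumer names:

expected_subscriptions = {"consumer-a", "consumer-b", "consumer-c"}
pull_subscriptions = [("consumer-a", None), ("consumer-c", None)]  # (consumer, state) stand-ins

existing_subs = {consumer for consumer, _ in pull_subscriptions}
missing_subs = expected_subscriptions - existing_subs

assert missing_subs == {"consumer-b"}  # this is what triggers the "missing" warning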
nucliadb_utils/nuclia_usage/utils/kb_usage_report.py CHANGED
@@ -22,7 +22,6 @@ import logging
 from collections.abc import Iterable
 from contextlib import suppress
 from datetime import datetime, timezone
-from typing import Optional

 from nats.js.client import JetStreamContext

@@ -90,14 +89,14 @@ class KbUsageReportUtility:
     def send_kb_usage(
         self,
         service: Service,
-        account_id: Optional[str],
-        kb_id: Optional[str],
+        account_id: str | None,
+        kb_id: str | None,
         kb_source: KBSource,
         processes: Iterable[Process] = (),
         predicts: Iterable[Predict] = (),
         searches: Iterable[Search] = (),
-        storage: Optional[Storage] = None,
-        activity_log_match: Optional[ActivityLogMatch] = None,
+        storage: Storage | None = None,
+        activity_log_match: ActivityLogMatch | None = None,
     ):
         usage = KbUsage()
         usage.service = service  # type: ignore
nucliadb_utils/run.py CHANGED
@@ -21,7 +21,7 @@ import asyncio
 import inspect
 import logging
 import signal
-from typing import Awaitable, Callable
+from collections.abc import Awaitable, Callable

 logger = logging.getLogger(__name__)

nucliadb_utils/settings.py CHANGED
@@ -18,7 +18,6 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.

 from enum import Enum
-from typing import Dict, List, Optional

 from pydantic import AliasChoices, Field, model_validator
 from pydantic_settings import BaseSettings
@@ -26,7 +25,7 @@ from pydantic_settings import BaseSettings

 class RunningSettings(BaseSettings):
     debug: bool = False
-    sentry_url: Optional[str] = None
+    sentry_url: str | None = None
     running_environment: str = Field(
         default="local",
         validation_alias=AliasChoices("environment", "running_environment"),
@@ -42,7 +41,7 @@ running_settings = RunningSettings()


 class HTTPSettings(BaseSettings):
-    cors_origins: List[str] = ["*"]
+    cors_origins: list[str] = ["*"]


 http_settings = HTTPSettings()
@@ -70,47 +69,47 @@ class StorageSettings(BaseSettings):
         default=FileBackendConfig.NOT_SET, description="File backend storage type"
     )

-    gcs_base64_creds: Optional[str] = Field(
+    gcs_base64_creds: str | None = Field(
         default=None,
-        description="GCS JSON credentials of a service account encoded in Base64: https://cloud.google.com/iam/docs/service-account-overview",  # noqa
+        description="GCS JSON credentials of a service account encoded in Base64: https://cloud.google.com/iam/docs/service-account-overview",
     )
-    gcs_bucket: Optional[str] = Field(
+    gcs_bucket: str | None = Field(
         default=None,
         description="GCS Bucket name where files are stored: https://cloud.google.com/storage/docs/buckets",
     )
-    gcs_location: Optional[str] = Field(
+    gcs_location: str | None = Field(
         default=None,
         description="GCS Bucket location: https://cloud.google.com/storage/docs/locations",
     )
-    gcs_project: Optional[str] = Field(
+    gcs_project: str | None = Field(
         default=None,
-        description="Google Cloud Project ID: https://cloud.google.com/resource-manager/docs/creating-managing-projects",  # noqa
+        description="Google Cloud Project ID: https://cloud.google.com/resource-manager/docs/creating-managing-projects",
     )
-    gcs_bucket_labels: Dict[str, str] = Field(
+    gcs_bucket_labels: dict[str, str] = Field(
         default={},
-        description="Map of labels with which GCS buckets will be labeled with: https://cloud.google.com/storage/docs/tags-and-labels",  # noqa
+        description="Map of labels with which GCS buckets will be labeled with: https://cloud.google.com/storage/docs/tags-and-labels",
     )
     gcs_endpoint_url: str = "https://www.googleapis.com"

-    s3_client_id: Optional[str] = None
-    s3_client_secret: Optional[str] = None
+    s3_client_id: str | None = None
+    s3_client_secret: str | None = None
     s3_ssl: bool = True
     s3_verify_ssl: bool = True
     s3_max_pool_connections: int = 30
-    s3_endpoint: Optional[str] = None
-    s3_region_name: Optional[str] = None
-    s3_kms_key_id: Optional[str] = None
-    s3_bucket: Optional[str] = Field(default=None, description="KnowledgeBox S3 bucket name template")
-    s3_bucket_tags: Dict[str, str] = Field(
+    s3_endpoint: str | None = None
+    s3_region_name: str | None = None
+    s3_kms_key_id: str | None = None
+    s3_bucket: str | None = Field(default=None, description="KnowledgeBox S3 bucket name template")
+    s3_bucket_tags: dict[str, str] = Field(
         default={},
-        description="Map of tags with which S3 buckets will be tagged with: https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutBucketTagging.html",  # noqa
+        description="Map of tags with which S3 buckets will be tagged with: https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutBucketTagging.html",
     )

-    local_files: Optional[str] = Field(
+    local_files: str | None = Field(
         default=None,
         description="If using LOCAL `file_backend` storage, directory where files should be stored",
     )
-    local_indexing_bucket: Optional[str] = Field(
+    local_indexing_bucket: str | None = Field(
         default="indexer",
         description="If using LOCAL `file_backend` storage, subdirectory where indexing data is stored",
     )
@@ -119,28 +118,28 @@
         description="Number of days that uploaded files are kept in Nulia's processing engine",
     )

-    azure_account_url: Optional[str] = Field(
+    azure_account_url: str | None = Field(
         default=None,
-        description="Azure Account URL. The driver implementation uses Azure's default credential authentication method: https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.defaultazurecredential?view=azure-python",  # noqa
+        description="Azure Account URL. The driver implementation uses Azure's default credential authentication method: https://learn.microsoft.com/en-us/python/api/azure-identity/azure.identity.defaultazurecredential?view=azure-python",
         examples=["https://<storageaccountname>.blob.core.windows.net"],
     )

-    azure_kb_account_url: Optional[str] = Field(
+    azure_kb_account_url: str | None = Field(
         default=None,
-        description="Azure Account URL for KB containers. If unspecified, uses `azure_account_url`",  # noqa
+        description="Azure Account URL for KB containers. If unspecified, uses `azure_account_url`",
         examples=["https://<storageaccountname>.blob.core.windows.net"],
     )

     # For testing purposes: Azurite docker image requires a connection string as it
     # doesn't support Azure's default credential authentication method
-    azure_connection_string: Optional[str] = None
+    azure_connection_string: str | None = None


 storage_settings = StorageSettings()


 class NucliaSettings(BaseSettings):
-    nuclia_service_account: Optional[str] = None
+    nuclia_service_account: str | None = None
     nuclia_public_url: str = "https://{zone}.nuclia.cloud"
     nuclia_processing_cluster_url: str = "http://processing-api.processing.svc.cluster.local:8080"
     nuclia_inner_predict_url: str = "http://predict.learning.svc.cluster.local:8080"
@@ -149,7 +148,7 @@ class NucliaSettings(BaseSettings):
     nuclia_zone: str = "europe-1"
     onprem: bool = True

-    nuclia_jwt_key: Optional[str] = None
+    nuclia_jwt_key: str | None = None
     nuclia_hash_seed: int = 42
     nuclia_partitions: int = 1

@@ -157,7 +156,7 @@
     dummy_predict: bool = False
     dummy_learning_services: bool = False
     local_predict: bool = False
-    local_predict_headers: Dict[str, str] = {}
+    local_predict_headers: dict[str, str] = {}

     @model_validator(mode="before")
     @classmethod
@@ -171,15 +170,15 @@ nuclia_settings = NucliaSettings()


 class NucliaDBSettings(BaseSettings):
-    nucliadb_ingest: Optional[str] = "ingest-orm-grpc.nucliadb.svc.cluster.local:8030"
+    nucliadb_ingest: str | None = "ingest-orm-grpc.nucliadb.svc.cluster.local:8030"


 nucliadb_settings = NucliaDBSettings()


 class TransactionSettings(BaseSettings):
-    transaction_jetstream_auth: Optional[str] = None
-    transaction_jetstream_servers: List[str] = ["nats://localhost:4222"]
+    transaction_jetstream_auth: str | None = None
+    transaction_jetstream_servers: list[str] = ["nats://localhost:4222"]
     transaction_local: bool = False
     transaction_commit_timeout: int = Field(
         default=60, description="Transaction commit timeout in seconds"
@@ -190,10 +189,10 @@ transaction_settings = TransactionSettings()


 class IndexingSettings(BaseSettings):
-    index_jetstream_servers: List[str] = []
-    index_jetstream_auth: Optional[str] = None
+    index_jetstream_servers: list[str] = []
+    index_jetstream_auth: str | None = None
     index_local: bool = False
-    index_nidx_subject: Optional[str] = None
+    index_nidx_subject: str | None = None
     index_searcher_refresh_interval: float = 1.0


@@ -202,9 +201,9 @@ indexing_settings = IndexingSettings()

 class AuditSettings(BaseSettings):
     audit_driver: str = "basic"
-    audit_jetstream_target: Optional[str] = "audit.{partition}.{type}"
-    audit_jetstream_servers: List[str] = []
-    audit_jetstream_auth: Optional[str] = None
+    audit_jetstream_target: str | None = "audit.{partition}.{type}"
+    audit_jetstream_servers: list[str] = []
+    audit_jetstream_auth: str | None = None
     audit_partitions: int = 3
     audit_stream: str = "audit"
     audit_hash_seed: int = 1234
@@ -214,9 +213,9 @@ audit_settings = AuditSettings()


 class UsageSettings(BaseSettings):
-    usage_jetstream_subject: Optional[str] = "kb-usage.nuclia_db"
-    usage_jetstream_servers: List[str] = []
-    usage_jetstream_auth: Optional[str] = None
+    usage_jetstream_subject: str | None = "kb-usage.nuclia_db"
+    usage_jetstream_servers: list[str] = []
+    usage_jetstream_auth: str | None = None
     usage_stream: str = "kb-usage"

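Most of the churn in this file (and across the release) is mechanical typing modernization: `Optional[X]` and `Union[X, Y]` become PEP 604 unions (`X | None`, `X | Y`), and `typing.List`/`typing.Dict` become the PEP 585 built-in generics. On Python 3.10+ the old and new spellings compare equal at runtime, which is why the pydantic fields keep validating identically:

from typing import Dict, List, Optional, Union

# Pure style migration: the annotations are interchangeable at runtime.
assert Optional[str] == (str | None)
assert Union[int, str] == (int | str)
assert List[str] == list[str]
assert Dict[str, str] == dict[str, str]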
nucliadb_utils/signals.py CHANGED
@@ -18,10 +18,10 @@
 # along with this program. If not, see <http://www.gnu.org/licenses/>.

 import asyncio
-from collections.abc import Awaitable
+from collections.abc import Awaitable, Callable
 from enum import Enum
 from inspect import iscoroutinefunction
-from typing import Any, Callable, Type
+from typing import Any

 from nucliadb_telemetry.errors import capture_exception
 from nucliadb_utils import logger
@@ -33,7 +33,7 @@ class ListenerPriority(Enum):


 class Signal:
-    def __init__(self, payload_model: Type):
+    def __init__(self, payload_model: type):
         self.payload_model_type = payload_model
         self.callbacks: dict[str, tuple[Callable[..., Awaitable], int]] = {}

nucliadb_utils/storages/azure.py CHANGED
@@ -20,9 +20,10 @@

 from __future__ import annotations

+import base64
 import logging
+from collections.abc import AsyncGenerator, AsyncIterator
 from datetime import datetime
-from typing import AsyncGenerator, AsyncIterator, Optional, Union

 from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
 from azure.identity import DefaultAzureCredential
@@ -65,7 +66,7 @@ class AzureStorageField(StorageField):
             origin_bucket_name, origin_uri, destination_bucket_name, destination_uri
         )

-    async def iter_data(self, range: Optional[Range] = None) -> AsyncGenerator[bytes, None]:
+    async def iter_data(self, range: Range | None = None) -> AsyncGenerator[bytes]:
         if self.field is not None:
             bucket = self.field.bucket_name
             key = self.field.uri
@@ -132,7 +133,7 @@
         self.field.ClearField("upload_uri")
         self.field.ClearField("parts")

-    async def exists(self) -> Optional[ObjectMetadata]:
+    async def exists(self) -> ObjectMetadata | None:
         key = None
         bucket = None
         if self.field is not None and self.field.uri != "":
@@ -170,9 +171,9 @@ class AzureStorage(Storage):
         self,
         account_url: str,
         kb_account_url: str,
-        deadletter_bucket: Optional[str] = "deadletter",
-        indexing_bucket: Optional[str] = "indexing",
-        connection_string: Optional[str] = None,
+        deadletter_bucket: str | None = "deadletter",
+        indexing_bucket: str | None = "indexing",
+        connection_string: str | None = None,
     ):
         self.object_store = AzureObjectStore(account_url, connection_string=connection_string)
         self.kb_object_store = AzureObjectStore(kb_account_url, connection_string=connection_string)
@@ -185,7 +186,7 @@
         else:
             return self.kb_object_store

-    async def initialize(self, service_name: Optional[str] = None):
+    async def initialize(self, service_name: str | None = None):
         await self.object_store.initialize()
         await self.kb_object_store.initialize()
         for bucket in [
@@ -208,7 +209,7 @@
         except KeyError:
             pass

-    async def create_bucket(self, bucket_name: str, kbid: Optional[str] = None):
+    async def create_bucket(self, bucket_name: str, kbid: str | None = None):
         if await self.object_store_for_bucket(bucket_name).bucket_exists(bucket_name):
             return
         await self.object_store_for_bucket(bucket_name).bucket_create(bucket_name)
@@ -230,8 +231,8 @@
             return await self.kb_object_store.bucket_delete(bucket_name)

     async def iterate_objects(
-        self, bucket: str, prefix: str, start: Optional[str] = None
-    ) -> AsyncGenerator[ObjectInfo, None]:
+        self, bucket: str, prefix: str, start: str | None = None
+    ) -> AsyncGenerator[ObjectInfo]:
         async for obj in self.object_store_for_bucket(bucket).iterate(bucket, prefix, start):
             yield obj

@@ -240,10 +241,10 @@


 class AzureObjectStore(ObjectStore):
-    def __init__(self, account_url: str, connection_string: Optional[str] = None):
+    def __init__(self, account_url: str, connection_string: str | None = None):
         self.account_url = account_url
         self.connection_string = connection_string
-        self._service_client: Optional[BlobServiceClient] = None
+        self._service_client: BlobServiceClient | None = None

     @property
     def service_client(self) -> BlobServiceClient:
@@ -344,11 +345,11 @@
         self,
         bucket: str,
         key: str,
-        data: Union[bytes, AsyncGenerator[bytes, None]],
+        data: bytes | AsyncGenerator[bytes],
         metadata: ObjectMetadata,
     ) -> None:
         container_client = self.service_client.get_container_client(bucket)
-        length: Optional[int] = None
+        length: int | None = None
         if isinstance(data, bytes):
             length = len(data)
             metadata.size = length
@@ -384,8 +385,8 @@
         return await downloader.readall()

     async def download_stream(
-        self, bucket: str, key: str, range: Optional[Range] = None
-    ) -> AsyncGenerator[bytes, None]:
+        self, bucket: str, key: str, range: Range | None = None
+    ) -> AsyncGenerator[bytes]:
         range = range or Range()
         container_client = self.service_client.get_container_client(bucket)
         blob_client = container_client.get_blob_client(key)
@@ -405,8 +406,8 @@
             yield chunk

     async def iterate(
-        self, bucket: str, prefix: str, start: Optional[str] = None
-    ) -> AsyncGenerator[ObjectInfo, None]:
+        self, bucket: str, prefix: str, start: str | None = None
+    ) -> AsyncGenerator[ObjectInfo]:
         container_client = self.service_client.get_container_client(bucket)
         async for blob in container_client.list_blobs(name_starts_with=prefix):
             if start and blob.name <= start:
@@ -426,13 +427,20 @@
     @ops_observer.wrap({"type": "multipart_start"})
     async def upload_multipart_start(self, bucket: str, key: str, metadata: ObjectMetadata) -> None:
         container_client = self.service_client.get_container_client(bucket)
-        custom_metadata = {key: str(value) for key, value in metadata.model_dump().items()}
+        custom_metadata = {
+            "base64_filename": base64.b64encode(metadata.filename.encode()).decode(),
+            "content_type": metadata.content_type,
+            "size": str(metadata.size),
+        }
         blob_client = container_client.get_blob_client(key)
+        safe_filename = (
+            metadata.filename.encode("ascii", "replace").decode().replace('"', "").replace("\n", "")
+        )
         await blob_client.create_append_blob(
             metadata=custom_metadata,
             content_settings=ContentSettings(
                 content_type=metadata.content_type,
-                content_disposition=f"attachment; filename={metadata.filename}",
+                content_disposition=f'attachment; filename="{safe_filename}"',
             ),
         )

@@ -460,7 +468,12 @@ def parse_object_metadata(properties: BlobProperties, key: str) -> ObjectMetadat
         size = int(custom_metadata_size)
     else:
         size = properties.size
-    filename = custom_metadata.get("filename") or key.split("/")[-1]
+
+    b64_filename = custom_metadata.get("base64_filename")
+    if b64_filename:
+        filename = base64.b64decode(b64_filename.encode()).decode()
+    else:
+        filename = key.split("/")[-1]
     content_type = custom_metadata.get("content_type") or properties.content_settings.content_type or ""
     return ObjectMetadata(
         filename=filename,
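Azure blob metadata values must be ASCII-safe, so storing raw filenames breaks on non-ASCII names, and an unquoted filename can corrupt the Content-Disposition header. The change base64-encodes the filename on write (`upload_multipart_start`) and decodes it on read (`parse_object_metadata`), falling back to the last path segment for older blobs. A minimal round-trip sketch of just that scheme (the two helper functions are hypothetical, written here only to isolate the encoding logic):

import base64

def encode_filename_metadata(filename: str) -> dict[str, str]:
    # Blob metadata must be ASCII-safe; base64 makes any Unicode name safe.
    return {"base64_filename": base64.b64encode(filename.encode()).decode()}

def decode_filename_metadata(custom_metadata: dict[str, str], key: str) -> str:
    b64 = custom_metadata.get("base64_filename")
    # Older blobs without the new key fall back to the last path segment.
    return base64.b64decode(b64.encode()).decode() if b64 else key.split("/")[-1]

meta = encode_filename_metadata("informe económico.pdf")
assert decode_filename_metadata(meta, "kb/f/field") == "informe económico.pdf"
assert decode_filename_metadata({}, "kb/f/field") == "field"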