nucliadb-utils 6.9.7.post5482__py3-none-any.whl → 6.10.0.post5689__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nucliadb-utils might be problematic (further details are linked from the original diff page).

Files changed (44)
  1. nucliadb_utils/asyncio_utils.py +3 -3
  2. nucliadb_utils/audit/audit.py +32 -22
  3. nucliadb_utils/audit/basic.py +22 -23
  4. nucliadb_utils/audit/stream.py +31 -31
  5. nucliadb_utils/authentication.py +8 -10
  6. nucliadb_utils/cache/nats.py +10 -12
  7. nucliadb_utils/cache/pubsub.py +5 -4
  8. nucliadb_utils/cache/settings.py +2 -3
  9. nucliadb_utils/debug.py +2 -2
  10. nucliadb_utils/encryption/settings.py +1 -2
  11. nucliadb_utils/fastapi/openapi.py +1 -2
  12. nucliadb_utils/fastapi/versioning.py +10 -6
  13. nucliadb_utils/featureflagging.py +7 -4
  14. nucliadb_utils/grpc.py +3 -3
  15. nucliadb_utils/helpers.py +1 -1
  16. nucliadb_utils/nats.py +15 -16
  17. nucliadb_utils/nuclia_usage/utils/kb_usage_report.py +4 -5
  18. nucliadb_utils/run.py +1 -1
  19. nucliadb_utils/settings.py +40 -41
  20. nucliadb_utils/signals.py +3 -3
  21. nucliadb_utils/storages/azure.py +18 -18
  22. nucliadb_utils/storages/gcs.py +22 -21
  23. nucliadb_utils/storages/local.py +8 -8
  24. nucliadb_utils/storages/nuclia.py +1 -2
  25. nucliadb_utils/storages/object_store.py +6 -6
  26. nucliadb_utils/storages/s3.py +22 -22
  27. nucliadb_utils/storages/settings.py +7 -8
  28. nucliadb_utils/storages/storage.py +29 -45
  29. nucliadb_utils/storages/utils.py +2 -3
  30. nucliadb_utils/store.py +2 -2
  31. nucliadb_utils/tests/asyncbenchmark.py +8 -10
  32. nucliadb_utils/tests/azure.py +2 -1
  33. nucliadb_utils/tests/fixtures.py +3 -2
  34. nucliadb_utils/tests/gcs.py +3 -2
  35. nucliadb_utils/tests/local.py +2 -1
  36. nucliadb_utils/tests/nats.py +1 -1
  37. nucliadb_utils/tests/s3.py +2 -1
  38. nucliadb_utils/transaction.py +16 -18
  39. nucliadb_utils/utilities.py +22 -24
  40. {nucliadb_utils-6.9.7.post5482.dist-info → nucliadb_utils-6.10.0.post5689.dist-info}/METADATA +5 -5
  41. nucliadb_utils-6.10.0.post5689.dist-info/RECORD +59 -0
  42. nucliadb_utils-6.9.7.post5482.dist-info/RECORD +0 -59
  43. {nucliadb_utils-6.9.7.post5482.dist-info → nucliadb_utils-6.10.0.post5689.dist-info}/WHEEL +0 -0
  44. {nucliadb_utils-6.9.7.post5482.dist-info → nucliadb_utils-6.10.0.post5689.dist-info}/top_level.txt +0 -0
@@ -19,7 +19,7 @@
19
19
  #
20
20
  import asyncio
21
21
  from collections.abc import Coroutine
22
- from typing import Any, Optional
22
+ from typing import Any
23
23
 
24
24
  from nucliadb_utils import logger
25
25
 
@@ -30,7 +30,7 @@ class ConcurrentRunner:
30
30
  Returns the results of the coroutines in the order they were scheduled.
31
31
  """
32
32
 
33
- def __init__(self, max_tasks: Optional[int] = None):
33
+ def __init__(self, max_tasks: int | None = None):
34
34
  self._tasks: list[asyncio.Task] = []
35
35
  self.max_tasks = asyncio.Semaphore(max_tasks) if max_tasks is not None else None
36
36
 
@@ -62,7 +62,7 @@ class ConcurrentRunner:
62
62
  return results
63
63
 
64
64
 
65
- async def run_concurrently(tasks: list[Coroutine], max_concurrent: Optional[int] = None) -> list[Any]:
65
+ async def run_concurrently(tasks: list[Coroutine], max_concurrent: int | None = None) -> list[Any]:
66
66
  """
67
67
  Runs a list of coroutines concurrently, with a maximum number of tasks running.
68
68
  Returns the results of the coroutines in the order they were scheduled.
@@ -17,7 +17,6 @@
17
17
  # You should have received a copy of the GNU Affero General Public License
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
- from typing import List, Optional
21
20
 
22
21
  from google.protobuf.timestamp_pb2 import Timestamp
23
22
  from nidx_protos.nodereader_pb2 import SearchRequest
@@ -34,12 +33,12 @@ class AuditStorage:
34
33
  *,
35
34
  kbid: str,
36
35
  audit_type: AuditRequest.AuditType.Value, # type: ignore
37
- when: Optional[Timestamp] = None,
38
- user: Optional[str] = None,
39
- origin: Optional[str] = None,
40
- rid: Optional[str] = None,
41
- field_metadata: Optional[List[FieldID]] = None,
42
- audit_fields: Optional[List[AuditField]] = None,
36
+ when: Timestamp | None = None,
37
+ user: str | None = None,
38
+ origin: str | None = None,
39
+ rid: str | None = None,
40
+ field_metadata: list[FieldID] | None = None,
41
+ audit_fields: list[AuditField] | None = None,
43
42
  ):
44
43
  raise NotImplementedError
45
44
 
@@ -70,7 +69,7 @@ class AuditStorage:
70
69
  search: SearchRequest,
71
70
  timeit: float,
72
71
  resources: int,
73
- retrieval_rephrased_question: Optional[str] = None,
72
+ retrieval_rephrased_question: str | None = None,
74
73
  ):
75
74
  raise NotImplementedError
76
75
 
@@ -81,22 +80,33 @@ class AuditStorage:
81
80
  client: int,
82
81
  origin: str,
83
82
  question: str,
84
- rephrased_question: Optional[str],
85
- retrieval_rephrased_question: Optional[str],
86
- chat_context: List[ChatContext],
87
- retrieved_context: List[RetrievedContext],
88
- answer: Optional[str],
89
- reasoning: Optional[str],
90
- learning_id: Optional[str],
83
+ rephrased_question: str | None,
84
+ retrieval_rephrased_question: str | None,
85
+ chat_context: list[ChatContext],
86
+ retrieved_context: list[RetrievedContext],
87
+ answer: str | None,
88
+ reasoning: str | None,
89
+ learning_id: str | None,
91
90
  status_code: int,
92
- model: Optional[str],
93
- rephrase_time: Optional[float] = None,
94
- generative_answer_time: Optional[float] = None,
95
- generative_answer_first_chunk_time: Optional[float] = None,
96
- generative_reasoning_first_chunk_time: Optional[float] = None,
91
+ model: str | None,
92
+ rephrase_time: float | None = None,
93
+ generative_answer_time: float | None = None,
94
+ generative_answer_first_chunk_time: float | None = None,
95
+ generative_reasoning_first_chunk_time: float | None = None,
97
96
  ):
98
97
  raise NotImplementedError
99
98
 
99
+ def retrieve(
100
+ self,
101
+ kbid: str,
102
+ user: str,
103
+ client: int,
104
+ origin: str,
105
+ retrieval_time: float,
106
+ ):
107
+ # TODO(decoupled-ask): implement audit for /retrieve
108
+ ...
109
+
100
110
  def report_storage(self, kbid: str, paragraphs: int, fields: int, bytes: int):
101
111
  raise NotImplementedError
102
112
 
@@ -120,7 +130,7 @@ class AuditStorage:
120
130
  learning_id: str,
121
131
  good: bool,
122
132
  task: int,
123
- feedback: Optional[str],
124
- text_block_id: Optional[str],
133
+ feedback: str | None,
134
+ text_block_id: str | None,
125
135
  ):
126
136
  raise NotImplementedError
@@ -17,7 +17,6 @@
17
17
  # You should have received a copy of the GNU Affero General Public License
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
- from typing import List, Optional
21
20
 
22
21
  from google.protobuf.timestamp_pb2 import Timestamp
23
22
  from nidx_protos.nodereader_pb2 import SearchRequest
@@ -34,19 +33,19 @@ class BasicAuditStorage(AuditStorage):
34
33
  self.initialized = True
35
34
 
36
35
  def message_to_str(self, message: BrokerMessage) -> str:
37
- return f"{message.type}+{message.multiid}+{message.audit.user}+{message.kbid}+{message.uuid}+{message.audit.when.ToJsonString()}+{message.audit.origin}+{message.audit.source}" # noqa
36
+ return f"{message.type}+{message.multiid}+{message.audit.user}+{message.kbid}+{message.uuid}+{message.audit.when.ToJsonString()}+{message.audit.origin}+{message.audit.source}"
38
37
 
39
38
  async def report_and_send(
40
39
  self,
41
40
  *,
42
41
  kbid: str,
43
42
  audit_type: AuditRequest.AuditType.Value, # type: ignore
44
- when: Optional[Timestamp] = None,
45
- user: Optional[str] = None,
46
- origin: Optional[str] = None,
47
- rid: Optional[str] = None,
48
- field_metadata: Optional[List[FieldID]] = None,
49
- audit_fields: Optional[List[AuditField]] = None,
43
+ when: Timestamp | None = None,
44
+ user: str | None = None,
45
+ origin: str | None = None,
46
+ rid: str | None = None,
47
+ field_metadata: list[FieldID] | None = None,
48
+ audit_fields: list[AuditField] | None = None,
50
49
  ):
51
50
  logger.debug(f"AUDIT {audit_type} {kbid} {user} {origin} {rid} {audit_fields}")
52
51
 
@@ -71,7 +70,7 @@ class BasicAuditStorage(AuditStorage):
71
70
  search: SearchRequest,
72
71
  timeit: float,
73
72
  resources: int,
74
- retrieval_rephrased_question: Optional[str] = None,
73
+ retrieval_rephrased_question: str | None = None,
75
74
  ):
76
75
  logger.debug(f"SEARCH {kbid} {user} {origin} ''{search}'' {timeit} {resources}")
77
76
 
@@ -82,19 +81,19 @@ class BasicAuditStorage(AuditStorage):
82
81
  client_type: int,
83
82
  origin: str,
84
83
  question: str,
85
- rephrased_question: Optional[str],
86
- retrieval_rephrased_question: Optional[str],
87
- chat_context: List[ChatContext],
88
- retrieved_context: List[RetrievedContext],
89
- answer: Optional[str],
90
- reasoning: Optional[str],
91
- learning_id: Optional[str],
84
+ rephrased_question: str | None,
85
+ retrieval_rephrased_question: str | None,
86
+ chat_context: list[ChatContext],
87
+ retrieved_context: list[RetrievedContext],
88
+ answer: str | None,
89
+ reasoning: str | None,
90
+ learning_id: str | None,
92
91
  status_code: int,
93
- model: Optional[str],
94
- rephrase_time: Optional[float] = None,
95
- generative_answer_time: Optional[float] = None,
96
- generative_answer_first_chunk_time: Optional[float] = None,
97
- generative_reasoning_first_chunk_time: Optional[float] = None,
92
+ model: str | None,
93
+ rephrase_time: float | None = None,
94
+ generative_answer_time: float | None = None,
95
+ generative_answer_first_chunk_time: float | None = None,
96
+ generative_reasoning_first_chunk_time: float | None = None,
98
97
  ):
99
98
  logger.debug(f"CHAT {kbid} {user} {origin}")
100
99
 
@@ -121,7 +120,7 @@ class BasicAuditStorage(AuditStorage):
121
120
  learning_id: str,
122
121
  good: bool,
123
122
  task: int,
124
- feedback: Optional[str],
125
- text_block_id: Optional[str],
123
+ feedback: str | None,
124
+ text_block_id: str | None,
126
125
  ):
127
126
  logger.debug(f"FEEDBACK {kbid} {user} {client_type} {origin}")
@@ -21,8 +21,8 @@ import asyncio
21
21
  import contextvars
22
22
  import json
23
23
  import time
24
+ from collections.abc import Callable
24
25
  from datetime import datetime, timezone
25
- from typing import Callable, List, Optional
26
26
 
27
27
  import backoff
28
28
  import mmh3
@@ -71,17 +71,17 @@ class RequestContext:
71
71
  self.path: str = ""
72
72
 
73
73
 
74
- request_context_var = contextvars.ContextVar[Optional[RequestContext]]("request_context", default=None)
74
+ request_context_var = contextvars.ContextVar[RequestContext | None]("request_context", default=None)
75
75
 
76
76
 
77
- def get_trace_id() -> Optional[str]:
77
+ def get_trace_id() -> str | None:
78
78
  span = get_current_span()
79
79
  if span is None:
80
80
  return None
81
81
  return format_trace_id(span.get_span_context().trace_id)
82
82
 
83
83
 
84
- def get_request_context() -> Optional[RequestContext]:
84
+ def get_request_context() -> RequestContext | None:
85
85
  return request_context_var.get()
86
86
 
87
87
 
@@ -103,7 +103,7 @@ def fill_audit_search_request(audit: AuditSearchRequest, request: SearchRequest)
103
103
 
104
104
 
105
105
  class AuditMiddleware(BaseHTTPMiddleware):
106
- def __init__(self, app: ASGIApp, audit_utility_getter: Callable[[], Optional[AuditStorage]]) -> None:
106
+ def __init__(self, app: ASGIApp, audit_utility_getter: Callable[[], AuditStorage | None]) -> None:
107
107
  self.audit_utility_getter = audit_utility_getter
108
108
  super().__init__(app)
109
109
 
@@ -155,17 +155,17 @@ KB_USAGE_STREAM_SUBJECT = "kb-usage.nuclia_db"
155
155
 
156
156
 
157
157
  class StreamAuditStorage(AuditStorage):
158
- task: Optional[asyncio.Task]
158
+ task: asyncio.Task | None
159
159
  initialized: bool
160
160
  queue: asyncio.Queue
161
161
 
162
162
  def __init__(
163
163
  self,
164
- nats_servers: List[str],
164
+ nats_servers: list[str],
165
165
  nats_target: str,
166
166
  partitions: int,
167
167
  seed: int,
168
- nats_creds: Optional[str] = None,
168
+ nats_creds: str | None = None,
169
169
  service: str = "nucliadb.audit",
170
170
  ):
171
171
  self.nats_servers = nats_servers
@@ -186,10 +186,10 @@ class StreamAuditStorage(AuditStorage):
186
186
 
187
187
  async def reconnected_cb(self):
188
188
  # See who we are connected to on reconnect.
189
- logger.info("Got reconnected to NATS {url}".format(url=self.nc.connected_url))
189
+ logger.info(f"Got reconnected to NATS {self.nc.connected_url}")
190
190
 
191
191
  async def error_cb(self, e):
192
- logger.error("There was an error connecting to NATS audit: {}".format(e), exc_info=True)
192
+ logger.error(f"There was an error connecting to NATS audit: {e}", exc_info=True)
193
193
 
194
194
  async def closed_cb(self):
195
195
  logger.info("Connection is closed on NATS")
@@ -269,12 +269,12 @@ class StreamAuditStorage(AuditStorage):
269
269
  *,
270
270
  kbid: str,
271
271
  audit_type: AuditRequest.AuditType.Value, # type: ignore
272
- when: Optional[Timestamp] = None,
273
- user: Optional[str] = None,
274
- origin: Optional[str] = None,
275
- rid: Optional[str] = None,
276
- field_metadata: Optional[List[FieldID]] = None,
277
- audit_fields: Optional[List[AuditField]] = None,
272
+ when: Timestamp | None = None,
273
+ user: str | None = None,
274
+ origin: str | None = None,
275
+ rid: str | None = None,
276
+ field_metadata: list[FieldID] | None = None,
277
+ audit_fields: list[AuditField] | None = None,
278
278
  ):
279
279
  auditrequest = AuditRequest()
280
280
 
@@ -369,7 +369,7 @@ class StreamAuditStorage(AuditStorage):
369
369
  search: SearchRequest,
370
370
  timeit: float,
371
371
  resources: int,
372
- retrieval_rephrased_question: Optional[str] = None,
372
+ retrieval_rephrased_question: str | None = None,
373
373
  ):
374
374
  context = get_request_context()
375
375
  if context is None:
@@ -420,19 +420,19 @@ class StreamAuditStorage(AuditStorage):
420
420
  client_type: int,
421
421
  origin: str,
422
422
  question: str,
423
- rephrased_question: Optional[str],
424
- retrieval_rephrased_question: Optional[str],
425
- chat_context: List[ChatContext],
426
- retrieved_context: List[RetrievedContext],
427
- answer: Optional[str],
428
- reasoning: Optional[str],
429
- learning_id: Optional[str],
423
+ rephrased_question: str | None,
424
+ retrieval_rephrased_question: str | None,
425
+ chat_context: list[ChatContext],
426
+ retrieved_context: list[RetrievedContext],
427
+ answer: str | None,
428
+ reasoning: str | None,
429
+ learning_id: str | None,
430
430
  status_code: int,
431
- model: Optional[str],
432
- rephrase_time: Optional[float] = None,
433
- generative_answer_time: Optional[float] = None,
434
- generative_answer_first_chunk_time: Optional[float] = None,
435
- generative_reasoning_first_chunk_time: Optional[float] = None,
431
+ model: str | None,
432
+ rephrase_time: float | None = None,
433
+ generative_answer_time: float | None = None,
434
+ generative_answer_first_chunk_time: float | None = None,
435
+ generative_reasoning_first_chunk_time: float | None = None,
436
436
  ):
437
437
  rcontext = get_request_context()
438
438
  if rcontext is None:
@@ -482,8 +482,8 @@ class StreamAuditStorage(AuditStorage):
482
482
  learning_id: str,
483
483
  good: bool,
484
484
  task: int,
485
- feedback: Optional[str],
486
- text_block_id: Optional[str],
485
+ feedback: str | None,
486
+ text_block_id: str | None,
487
487
  ):
488
488
  rcontext = get_request_context()
489
489
  if rcontext is None:
@@ -17,12 +17,10 @@
17
17
  # You should have received a copy of the GNU Affero General Public License
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
 
20
- import asyncio
21
20
  import functools
22
21
  import inspect
23
22
  import typing
24
23
  from enum import Enum
25
- from typing import Optional, Tuple
26
24
 
27
25
  from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser
28
26
  from starlette.exceptions import HTTPException
@@ -32,7 +30,7 @@ from starlette.websockets import WebSocket
32
30
 
33
31
 
34
32
  class NucliaUser(BaseUser):
35
- def __init__(self, username: str, security_groups: Optional[list[str]] = None) -> None:
33
+ def __init__(self, username: str, security_groups: list[str] | None = None) -> None:
36
34
  self.username = username
37
35
  self._security_groups = security_groups
38
36
 
@@ -45,7 +43,7 @@ class NucliaUser(BaseUser):
45
43
  return self.username
46
44
 
47
45
  @property
48
- def security_groups(self) -> Optional[list[str]]:
46
+ def security_groups(self) -> list[str] | None:
49
47
  return self._security_groups
50
48
 
51
49
 
@@ -63,7 +61,7 @@ class NucliaCloudAuthenticationBackend(AuthenticationBackend):
63
61
  self.user_header = user_header
64
62
  self.security_groups_header = security_groups_header
65
63
 
66
- async def authenticate(self, request: HTTPConnection) -> Optional[Tuple[AuthCredentials, BaseUser]]:
64
+ async def authenticate(self, request: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
67
65
  if self.roles_header not in request.headers:
68
66
  return None
69
67
  else:
@@ -74,9 +72,9 @@ class NucliaCloudAuthenticationBackend(AuthenticationBackend):
74
72
  if self.user_header in request.headers:
75
73
  user = request.headers[self.user_header]
76
74
 
77
- raw_security_groups: Optional[str] = request.headers.get(self.security_groups_header)
75
+ raw_security_groups: str | None = request.headers.get(self.security_groups_header)
78
76
 
79
- security_groups: Optional[list[str]] = None
77
+ security_groups: list[str] | None = None
80
78
  if raw_security_groups is not None:
81
79
  security_groups = raw_security_groups.split(";")
82
80
 
@@ -99,9 +97,9 @@ def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bo
99
97
 
100
98
 
101
99
  def requires(
102
- scopes: typing.Union[str, typing.Sequence[str]],
100
+ scopes: str | typing.Sequence[str],
103
101
  status_code: int = 403,
104
- redirect: Optional[str] = None,
102
+ redirect: str | None = None,
105
103
  ) -> typing.Callable:
106
104
  # As a fastapi requirement, custom Enum classes have to inherit also from
107
105
  # string, so we MUST check for Enum before str
@@ -137,7 +135,7 @@ def requires(
137
135
 
138
136
  return websocket_wrapper
139
137
 
140
- elif asyncio.iscoroutinefunction(func):
138
+ elif inspect.iscoroutinefunction(func):
141
139
  # Handle async request/response functions.
142
140
  @functools.wraps(func)
143
141
  async def async_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
@@ -1,4 +1,3 @@
1
- # -*- coding: utf-8 -*-
2
1
  # Copyright (C) 2021 Bosutech XXI S.L.
3
2
  #
4
3
  # nucliadb is offered under the AGPL v3.0 and as commercial software.
@@ -23,7 +22,6 @@ import functools
23
22
  import os
24
23
  import uuid
25
24
  from inspect import iscoroutinefunction
26
- from typing import Dict, List, Optional, Union
27
25
 
28
26
  import nats
29
27
  import nats.errors
@@ -54,16 +52,16 @@ async def wait_for_it(future: asyncio.Future, msg):
54
52
  class NatsPubsub(PubSubDriver[Msg]):
55
53
  _jetstream = None
56
54
  _jsm = None
57
- _subscriptions: Dict[str, Subscription]
55
+ _subscriptions: dict[str, Subscription]
58
56
  async_callback = True
59
57
 
60
58
  def __init__(
61
59
  self,
62
60
  thread: bool = False,
63
- name: Optional[str] = "natsutility",
61
+ name: str | None = "natsutility",
64
62
  timeout: float = 2.0,
65
- hosts: Optional[List[str]] = None,
66
- user_credentials_file: Optional[str] = None,
63
+ hosts: list[str] | None = None,
64
+ user_credentials_file: str | None = None,
67
65
  ):
68
66
  self._hosts = hosts or []
69
67
  self._timeout = timeout
@@ -73,11 +71,11 @@ class NatsPubsub(PubSubDriver[Msg]):
73
71
  self._uuid = os.environ.get("HOSTNAME", uuid.uuid4().hex)
74
72
  self.initialized = False
75
73
  self.lock = asyncio.Lock()
76
- self.nc: Union[Client, NatsClientTelemetry, None] = None
74
+ self.nc: Client | NatsClientTelemetry | None = None
77
75
  self.user_credentials_file = user_credentials_file
78
76
 
79
77
  @property
80
- def jetstream(self) -> Union[JetStreamContext, JetStreamContextTelemetry]:
78
+ def jetstream(self) -> JetStreamContext | JetStreamContextTelemetry:
81
79
  if self.nc is None:
82
80
  raise AttributeError("NC not initialized")
83
81
  if self._jetstream is None:
@@ -154,10 +152,10 @@ class NatsPubsub(PubSubDriver[Msg]):
154
152
 
155
153
  async def reconnected_cb(self):
156
154
  # See who we are connected to on reconnect.
157
- logger.info("Got reconnected NATS to {url}".format(url=self.nc.connected_url.netloc))
155
+ logger.info(f"Got reconnected NATS to {self.nc.connected_url.netloc}")
158
156
 
159
157
  async def error_cb(self, e):
160
- logger.info("There was an error connecting to NATS {}".format(e), exc_info=True)
158
+ logger.info(f"There was an error connecting to NATS {e}", exc_info=True)
161
159
 
162
160
  async def closed_cb(self):
163
161
  logger.info("Connection is closed to NATS")
@@ -173,7 +171,7 @@ class NatsPubsub(PubSubDriver[Msg]):
173
171
  else:
174
172
  raise ErrConnectionClosed("Could not subscribe")
175
173
 
176
- async def subscribe(self, handler: Callback, key, group="", subscription_id: Optional[str] = None):
174
+ async def subscribe(self, handler: Callback, key, group="", subscription_id: str | None = None):
177
175
  if subscription_id is None:
178
176
  subscription_id = key
179
177
 
@@ -189,7 +187,7 @@ class NatsPubsub(PubSubDriver[Msg]):
189
187
  else:
190
188
  raise ErrConnectionClosed("Could not subscribe")
191
189
 
192
- async def unsubscribe(self, key: str, subscription_id: Optional[str] = None):
190
+ async def unsubscribe(self, key: str, subscription_id: str | None = None):
193
191
  if subscription_id is None:
194
192
  subscription_id = key
195
193
 
@@ -17,7 +17,8 @@
17
17
  # You should have received a copy of the GNU Affero General Public License
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
- from typing import Awaitable, Callable, Generic, Optional, TypeVar
20
+ from collections.abc import Awaitable, Callable
21
+ from typing import Generic, TypeVar
21
22
 
22
23
  T = TypeVar("T")
23
24
 
@@ -37,15 +38,15 @@ class PubSubDriver(Generic[T]):
37
38
  async def publish(self, channel_name: str, data: bytes):
38
39
  raise NotImplementedError()
39
40
 
40
- async def unsubscribe(self, key: str, subscription_id: Optional[str] = None):
41
+ async def unsubscribe(self, key: str, subscription_id: str | None = None):
41
42
  raise NotImplementedError()
42
43
 
43
44
  async def subscribe(
44
45
  self,
45
46
  handler: Callback,
46
47
  key: str,
47
- group: Optional[str] = None,
48
- subscription_id: Optional[str] = None,
48
+ group: str | None = None,
49
+ subscription_id: str | None = None,
49
50
  ):
50
51
  raise NotImplementedError()
51
52
 
@@ -17,14 +17,13 @@
17
17
  # You should have received a copy of the GNU Affero General Public License
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
 
20
- from typing import List, Optional
21
20
 
22
21
  from pydantic_settings import BaseSettings
23
22
 
24
23
 
25
24
  class Settings(BaseSettings):
26
- cache_pubsub_nats_url: List[str] = []
27
- cache_pubsub_nats_auth: Optional[str] = None
25
+ cache_pubsub_nats_url: list[str] = []
26
+ cache_pubsub_nats_auth: str | None = None
28
27
 
29
28
 
30
29
  settings = Settings()
nucliadb_utils/debug.py CHANGED
@@ -54,7 +54,7 @@ def display_top(snapshot, key_type="lineno", limit=10): # pragma: no cover
54
54
  print(f"Top {limit} lines")
55
55
  for index, stat in enumerate(top_stats[:limit], 1):
56
56
  frame = stat.traceback[0]
57
- print("#%s: %s:%s: %.1f KiB" % (index, frame.filename, frame.lineno, stat.size / 1024))
57
+ print(f"#{index}: {frame.filename}:{frame.lineno}: {stat.size / 1024:.1f} KiB")
58
58
  line = linecache.getline(frame.filename, frame.lineno).strip()
59
59
  if line:
60
60
  print(" %s" % line)
@@ -62,6 +62,6 @@ def display_top(snapshot, key_type="lineno", limit=10): # pragma: no cover
62
62
  other = top_stats[limit:]
63
63
  if other:
64
64
  size = sum(stat.size for stat in other)
65
- print("%s other: %.1f KiB" % (len(other), size / 1024))
65
+ print(f"{len(other)} other: {size / 1024:.1f} KiB")
66
66
  total = sum(stat.size for stat in top_stats)
67
67
  print("Total allocated size: %.1f KiB" % (total / 1024))
@@ -16,14 +16,13 @@
16
16
  #
17
17
  # You should have received a copy of the GNU Affero General Public License
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
- from typing import Optional
20
19
 
21
20
  from pydantic import Field
22
21
  from pydantic_settings import BaseSettings
23
22
 
24
23
 
25
24
  class EncryptionSettings(BaseSettings):
26
- encryption_secret_key: Optional[str] = Field(
25
+ encryption_secret_key: str | None = Field(
27
26
  default=None,
28
27
  title="Encryption Secret Key",
29
28
  description="""Secret key used for encryption and decryption of sensitive data in the database.
@@ -17,7 +17,6 @@
17
17
  # You should have received a copy of the GNU Affero General Public License
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
 
20
- from typing import Union
21
20
 
22
21
  from fastapi.applications import FastAPI
23
22
  from starlette.routing import Mount
@@ -34,7 +33,7 @@ def format_scopes(scope_list):
34
33
  return "\n".join(f"- `{scope}`" for scope in scope_list)
35
34
 
36
35
 
37
- def extend_openapi(app: Union[FastAPI, Mount]): # pragma: no cover
36
+ def extend_openapi(app: FastAPI | Mount): # pragma: no cover
38
37
  for route in app.routes:
39
38
  # mypy complains about BaseRoute not having endpoint and
40
39
  # description attributes, but routes passed here always have
@@ -20,9 +20,11 @@
20
20
  # This code is inspired by fastapi_versioning 1/3/2022 with MIT licence
21
21
 
22
22
  from collections import defaultdict
23
- from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, cast
23
+ from collections.abc import Callable
24
+ from typing import Any, Sequence, TypeVar, cast
24
25
 
25
26
  from fastapi import FastAPI
27
+ from fastapi.middleware import Middleware
26
28
  from fastapi.routing import APIRoute
27
29
  from starlette.routing import BaseRoute
28
30
 
@@ -39,8 +41,8 @@ def version(major: int, minor: int = 0) -> Callable[[CallableT], CallableT]: #
39
41
 
40
42
  def version_to_route(
41
43
  route: BaseRoute,
42
- default_version: Tuple[int, int],
43
- ) -> Tuple[Tuple[int, int], APIRoute]: # pragma: no cover
44
+ default_version: tuple[int, int],
45
+ ) -> tuple[tuple[int, int], APIRoute]: # pragma: no cover
44
46
  api_route = cast(APIRoute, route)
45
47
  version = getattr(api_route.endpoint, "_api_version", default_version)
46
48
  return version, api_route
@@ -50,17 +52,19 @@ def VersionedFastAPI(
50
52
  app: FastAPI,
51
53
  version_format: str = "{major}.{minor}",
52
54
  prefix_format: str = "/v{major}_{minor}",
53
- default_version: Tuple[int, int] = (1, 0),
55
+ default_version: tuple[int, int] = (1, 0),
54
56
  enable_latest: bool = False,
55
- kwargs: Optional[Dict[str, object]] = None,
57
+ middleware: Sequence[Middleware] | None = None,
58
+ kwargs: dict[str, object] | None = None,
56
59
  ) -> FastAPI: # pragma: no cover
57
60
  kwargs = kwargs or {}
58
61
 
59
62
  parent_app = FastAPI(
60
63
  title=app.title,
64
+ middleware=middleware,
61
65
  **kwargs, # type: ignore
62
66
  )
63
- version_route_mapping: Dict[Tuple[int, int], List[APIRoute]] = defaultdict(list)
67
+ version_route_mapping: dict[tuple[int, int], list[APIRoute]] = defaultdict(list)
64
68
  version_routes = [version_to_route(route, default_version) for route in app.routes]
65
69
 
66
70
  for version, route in version_routes: