nucliadb-utils 6.9.6.post5473__py3-none-any.whl → 6.10.0.post5732__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nucliadb-utils might be problematic. Click here for more details.
- nucliadb_utils/asyncio_utils.py +3 -3
- nucliadb_utils/audit/audit.py +41 -31
- nucliadb_utils/audit/basic.py +22 -23
- nucliadb_utils/audit/stream.py +31 -31
- nucliadb_utils/authentication.py +8 -10
- nucliadb_utils/cache/nats.py +10 -12
- nucliadb_utils/cache/pubsub.py +5 -4
- nucliadb_utils/cache/settings.py +2 -3
- nucliadb_utils/debug.py +2 -2
- nucliadb_utils/encryption/settings.py +1 -2
- nucliadb_utils/fastapi/openapi.py +1 -2
- nucliadb_utils/fastapi/versioning.py +10 -6
- nucliadb_utils/featureflagging.py +7 -4
- nucliadb_utils/grpc.py +3 -3
- nucliadb_utils/helpers.py +1 -1
- nucliadb_utils/nats.py +15 -16
- nucliadb_utils/nuclia_usage/utils/kb_usage_report.py +4 -5
- nucliadb_utils/run.py +1 -1
- nucliadb_utils/settings.py +40 -41
- nucliadb_utils/signals.py +3 -3
- nucliadb_utils/storages/azure.py +18 -18
- nucliadb_utils/storages/gcs.py +22 -21
- nucliadb_utils/storages/local.py +8 -8
- nucliadb_utils/storages/nuclia.py +1 -2
- nucliadb_utils/storages/object_store.py +6 -6
- nucliadb_utils/storages/s3.py +22 -22
- nucliadb_utils/storages/settings.py +7 -8
- nucliadb_utils/storages/storage.py +29 -45
- nucliadb_utils/storages/utils.py +2 -3
- nucliadb_utils/store.py +2 -2
- nucliadb_utils/tests/asyncbenchmark.py +8 -10
- nucliadb_utils/tests/azure.py +2 -1
- nucliadb_utils/tests/fixtures.py +3 -2
- nucliadb_utils/tests/gcs.py +3 -2
- nucliadb_utils/tests/local.py +2 -1
- nucliadb_utils/tests/nats.py +1 -1
- nucliadb_utils/tests/s3.py +2 -1
- nucliadb_utils/transaction.py +16 -18
- nucliadb_utils/utilities.py +22 -24
- {nucliadb_utils-6.9.6.post5473.dist-info → nucliadb_utils-6.10.0.post5732.dist-info}/METADATA +5 -5
- nucliadb_utils-6.10.0.post5732.dist-info/RECORD +59 -0
- nucliadb_utils-6.9.6.post5473.dist-info/RECORD +0 -59
- {nucliadb_utils-6.9.6.post5473.dist-info → nucliadb_utils-6.10.0.post5732.dist-info}/WHEEL +0 -0
- {nucliadb_utils-6.9.6.post5473.dist-info → nucliadb_utils-6.10.0.post5732.dist-info}/top_level.txt +0 -0
nucliadb_utils/asyncio_utils.py
CHANGED
|
@@ -19,7 +19,7 @@
|
|
|
19
19
|
#
|
|
20
20
|
import asyncio
|
|
21
21
|
from collections.abc import Coroutine
|
|
22
|
-
from typing import Any
|
|
22
|
+
from typing import Any
|
|
23
23
|
|
|
24
24
|
from nucliadb_utils import logger
|
|
25
25
|
|
|
@@ -30,7 +30,7 @@ class ConcurrentRunner:
|
|
|
30
30
|
Returns the results of the coroutines in the order they were scheduled.
|
|
31
31
|
"""
|
|
32
32
|
|
|
33
|
-
def __init__(self, max_tasks:
|
|
33
|
+
def __init__(self, max_tasks: int | None = None):
|
|
34
34
|
self._tasks: list[asyncio.Task] = []
|
|
35
35
|
self.max_tasks = asyncio.Semaphore(max_tasks) if max_tasks is not None else None
|
|
36
36
|
|
|
@@ -62,7 +62,7 @@ class ConcurrentRunner:
|
|
|
62
62
|
return results
|
|
63
63
|
|
|
64
64
|
|
|
65
|
-
async def run_concurrently(tasks: list[Coroutine], max_concurrent:
|
|
65
|
+
async def run_concurrently(tasks: list[Coroutine], max_concurrent: int | None = None) -> list[Any]:
|
|
66
66
|
"""
|
|
67
67
|
Runs a list of coroutines concurrently, with a maximum number of tasks running.
|
|
68
68
|
Returns the results of the coroutines in the order they were scheduled.
|
nucliadb_utils/audit/audit.py
CHANGED
|
@@ -17,7 +17,6 @@
|
|
|
17
17
|
# You should have received a copy of the GNU Affero General Public License
|
|
18
18
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
19
19
|
#
|
|
20
|
-
from typing import List, Optional
|
|
21
20
|
|
|
22
21
|
from google.protobuf.timestamp_pb2 import Timestamp
|
|
23
22
|
from nidx_protos.nodereader_pb2 import SearchRequest
|
|
@@ -34,14 +33,14 @@ class AuditStorage:
|
|
|
34
33
|
*,
|
|
35
34
|
kbid: str,
|
|
36
35
|
audit_type: AuditRequest.AuditType.Value, # type: ignore
|
|
37
|
-
when:
|
|
38
|
-
user:
|
|
39
|
-
origin:
|
|
40
|
-
rid:
|
|
41
|
-
field_metadata:
|
|
42
|
-
audit_fields:
|
|
36
|
+
when: Timestamp | None = None,
|
|
37
|
+
user: str | None = None,
|
|
38
|
+
origin: str | None = None,
|
|
39
|
+
rid: str | None = None,
|
|
40
|
+
field_metadata: list[FieldID] | None = None,
|
|
41
|
+
audit_fields: list[AuditField] | None = None,
|
|
43
42
|
):
|
|
44
|
-
raise NotImplementedError
|
|
43
|
+
raise NotImplementedError()
|
|
45
44
|
|
|
46
45
|
async def initialize(self):
|
|
47
46
|
pass
|
|
@@ -56,10 +55,10 @@ class AuditStorage:
|
|
|
56
55
|
user: str,
|
|
57
56
|
origin: str,
|
|
58
57
|
):
|
|
59
|
-
raise NotImplementedError
|
|
58
|
+
raise NotImplementedError()
|
|
60
59
|
|
|
61
60
|
def send(self, msg: AuditRequest):
|
|
62
|
-
raise NotImplementedError
|
|
61
|
+
raise NotImplementedError()
|
|
63
62
|
|
|
64
63
|
def search(
|
|
65
64
|
self,
|
|
@@ -70,9 +69,9 @@ class AuditStorage:
|
|
|
70
69
|
search: SearchRequest,
|
|
71
70
|
timeit: float,
|
|
72
71
|
resources: int,
|
|
73
|
-
retrieval_rephrased_question:
|
|
72
|
+
retrieval_rephrased_question: str | None = None,
|
|
74
73
|
):
|
|
75
|
-
raise NotImplementedError
|
|
74
|
+
raise NotImplementedError()
|
|
76
75
|
|
|
77
76
|
def chat(
|
|
78
77
|
self,
|
|
@@ -81,24 +80,35 @@ class AuditStorage:
|
|
|
81
80
|
client: int,
|
|
82
81
|
origin: str,
|
|
83
82
|
question: str,
|
|
84
|
-
rephrased_question:
|
|
85
|
-
retrieval_rephrased_question:
|
|
86
|
-
chat_context:
|
|
87
|
-
retrieved_context:
|
|
88
|
-
answer:
|
|
89
|
-
reasoning:
|
|
90
|
-
learning_id:
|
|
83
|
+
rephrased_question: str | None,
|
|
84
|
+
retrieval_rephrased_question: str | None,
|
|
85
|
+
chat_context: list[ChatContext],
|
|
86
|
+
retrieved_context: list[RetrievedContext],
|
|
87
|
+
answer: str | None,
|
|
88
|
+
reasoning: str | None,
|
|
89
|
+
learning_id: str | None,
|
|
91
90
|
status_code: int,
|
|
92
|
-
model:
|
|
93
|
-
rephrase_time:
|
|
94
|
-
generative_answer_time:
|
|
95
|
-
generative_answer_first_chunk_time:
|
|
96
|
-
generative_reasoning_first_chunk_time:
|
|
91
|
+
model: str | None,
|
|
92
|
+
rephrase_time: float | None = None,
|
|
93
|
+
generative_answer_time: float | None = None,
|
|
94
|
+
generative_answer_first_chunk_time: float | None = None,
|
|
95
|
+
generative_reasoning_first_chunk_time: float | None = None,
|
|
97
96
|
):
|
|
98
|
-
raise NotImplementedError
|
|
97
|
+
raise NotImplementedError()
|
|
98
|
+
|
|
99
|
+
def retrieve(
|
|
100
|
+
self,
|
|
101
|
+
kbid: str,
|
|
102
|
+
user: str,
|
|
103
|
+
client: int,
|
|
104
|
+
origin: str,
|
|
105
|
+
retrieval_time: float,
|
|
106
|
+
):
|
|
107
|
+
# TODO(decoupled-ask): implement audit for /retrieve
|
|
108
|
+
...
|
|
99
109
|
|
|
100
110
|
def report_storage(self, kbid: str, paragraphs: int, fields: int, bytes: int):
|
|
101
|
-
raise NotImplementedError
|
|
111
|
+
raise NotImplementedError()
|
|
102
112
|
|
|
103
113
|
def report_resources(
|
|
104
114
|
self,
|
|
@@ -106,10 +116,10 @@ class AuditStorage:
|
|
|
106
116
|
kbid: str,
|
|
107
117
|
resources: int,
|
|
108
118
|
):
|
|
109
|
-
raise NotImplementedError
|
|
119
|
+
raise NotImplementedError()
|
|
110
120
|
|
|
111
121
|
def delete_kb(self, kbid: str):
|
|
112
|
-
raise NotImplementedError
|
|
122
|
+
raise NotImplementedError()
|
|
113
123
|
|
|
114
124
|
def feedback(
|
|
115
125
|
self,
|
|
@@ -120,7 +130,7 @@ class AuditStorage:
|
|
|
120
130
|
learning_id: str,
|
|
121
131
|
good: bool,
|
|
122
132
|
task: int,
|
|
123
|
-
feedback:
|
|
124
|
-
text_block_id:
|
|
133
|
+
feedback: str | None,
|
|
134
|
+
text_block_id: str | None,
|
|
125
135
|
):
|
|
126
|
-
raise NotImplementedError
|
|
136
|
+
raise NotImplementedError()
|
nucliadb_utils/audit/basic.py
CHANGED
|
@@ -17,7 +17,6 @@
|
|
|
17
17
|
# You should have received a copy of the GNU Affero General Public License
|
|
18
18
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
19
19
|
#
|
|
20
|
-
from typing import List, Optional
|
|
21
20
|
|
|
22
21
|
from google.protobuf.timestamp_pb2 import Timestamp
|
|
23
22
|
from nidx_protos.nodereader_pb2 import SearchRequest
|
|
@@ -34,19 +33,19 @@ class BasicAuditStorage(AuditStorage):
|
|
|
34
33
|
self.initialized = True
|
|
35
34
|
|
|
36
35
|
def message_to_str(self, message: BrokerMessage) -> str:
|
|
37
|
-
return f"{message.type}+{message.multiid}+{message.audit.user}+{message.kbid}+{message.uuid}+{message.audit.when.ToJsonString()}+{message.audit.origin}+{message.audit.source}"
|
|
36
|
+
return f"{message.type}+{message.multiid}+{message.audit.user}+{message.kbid}+{message.uuid}+{message.audit.when.ToJsonString()}+{message.audit.origin}+{message.audit.source}"
|
|
38
37
|
|
|
39
38
|
async def report_and_send(
|
|
40
39
|
self,
|
|
41
40
|
*,
|
|
42
41
|
kbid: str,
|
|
43
42
|
audit_type: AuditRequest.AuditType.Value, # type: ignore
|
|
44
|
-
when:
|
|
45
|
-
user:
|
|
46
|
-
origin:
|
|
47
|
-
rid:
|
|
48
|
-
field_metadata:
|
|
49
|
-
audit_fields:
|
|
43
|
+
when: Timestamp | None = None,
|
|
44
|
+
user: str | None = None,
|
|
45
|
+
origin: str | None = None,
|
|
46
|
+
rid: str | None = None,
|
|
47
|
+
field_metadata: list[FieldID] | None = None,
|
|
48
|
+
audit_fields: list[AuditField] | None = None,
|
|
50
49
|
):
|
|
51
50
|
logger.debug(f"AUDIT {audit_type} {kbid} {user} {origin} {rid} {audit_fields}")
|
|
52
51
|
|
|
@@ -71,7 +70,7 @@ class BasicAuditStorage(AuditStorage):
|
|
|
71
70
|
search: SearchRequest,
|
|
72
71
|
timeit: float,
|
|
73
72
|
resources: int,
|
|
74
|
-
retrieval_rephrased_question:
|
|
73
|
+
retrieval_rephrased_question: str | None = None,
|
|
75
74
|
):
|
|
76
75
|
logger.debug(f"SEARCH {kbid} {user} {origin} ''{search}'' {timeit} {resources}")
|
|
77
76
|
|
|
@@ -82,19 +81,19 @@ class BasicAuditStorage(AuditStorage):
|
|
|
82
81
|
client_type: int,
|
|
83
82
|
origin: str,
|
|
84
83
|
question: str,
|
|
85
|
-
rephrased_question:
|
|
86
|
-
retrieval_rephrased_question:
|
|
87
|
-
chat_context:
|
|
88
|
-
retrieved_context:
|
|
89
|
-
answer:
|
|
90
|
-
reasoning:
|
|
91
|
-
learning_id:
|
|
84
|
+
rephrased_question: str | None,
|
|
85
|
+
retrieval_rephrased_question: str | None,
|
|
86
|
+
chat_context: list[ChatContext],
|
|
87
|
+
retrieved_context: list[RetrievedContext],
|
|
88
|
+
answer: str | None,
|
|
89
|
+
reasoning: str | None,
|
|
90
|
+
learning_id: str | None,
|
|
92
91
|
status_code: int,
|
|
93
|
-
model:
|
|
94
|
-
rephrase_time:
|
|
95
|
-
generative_answer_time:
|
|
96
|
-
generative_answer_first_chunk_time:
|
|
97
|
-
generative_reasoning_first_chunk_time:
|
|
92
|
+
model: str | None,
|
|
93
|
+
rephrase_time: float | None = None,
|
|
94
|
+
generative_answer_time: float | None = None,
|
|
95
|
+
generative_answer_first_chunk_time: float | None = None,
|
|
96
|
+
generative_reasoning_first_chunk_time: float | None = None,
|
|
98
97
|
):
|
|
99
98
|
logger.debug(f"CHAT {kbid} {user} {origin}")
|
|
100
99
|
|
|
@@ -121,7 +120,7 @@ class BasicAuditStorage(AuditStorage):
|
|
|
121
120
|
learning_id: str,
|
|
122
121
|
good: bool,
|
|
123
122
|
task: int,
|
|
124
|
-
feedback:
|
|
125
|
-
text_block_id:
|
|
123
|
+
feedback: str | None,
|
|
124
|
+
text_block_id: str | None,
|
|
126
125
|
):
|
|
127
126
|
logger.debug(f"FEEDBACK {kbid} {user} {client_type} {origin}")
|
nucliadb_utils/audit/stream.py
CHANGED
|
@@ -21,8 +21,8 @@ import asyncio
|
|
|
21
21
|
import contextvars
|
|
22
22
|
import json
|
|
23
23
|
import time
|
|
24
|
+
from collections.abc import Callable
|
|
24
25
|
from datetime import datetime, timezone
|
|
25
|
-
from typing import Callable, List, Optional
|
|
26
26
|
|
|
27
27
|
import backoff
|
|
28
28
|
import mmh3
|
|
@@ -71,17 +71,17 @@ class RequestContext:
|
|
|
71
71
|
self.path: str = ""
|
|
72
72
|
|
|
73
73
|
|
|
74
|
-
request_context_var = contextvars.ContextVar[
|
|
74
|
+
request_context_var = contextvars.ContextVar[RequestContext | None]("request_context", default=None)
|
|
75
75
|
|
|
76
76
|
|
|
77
|
-
def get_trace_id() ->
|
|
77
|
+
def get_trace_id() -> str | None:
|
|
78
78
|
span = get_current_span()
|
|
79
79
|
if span is None:
|
|
80
80
|
return None
|
|
81
81
|
return format_trace_id(span.get_span_context().trace_id)
|
|
82
82
|
|
|
83
83
|
|
|
84
|
-
def get_request_context() ->
|
|
84
|
+
def get_request_context() -> RequestContext | None:
|
|
85
85
|
return request_context_var.get()
|
|
86
86
|
|
|
87
87
|
|
|
@@ -103,7 +103,7 @@ def fill_audit_search_request(audit: AuditSearchRequest, request: SearchRequest)
|
|
|
103
103
|
|
|
104
104
|
|
|
105
105
|
class AuditMiddleware(BaseHTTPMiddleware):
|
|
106
|
-
def __init__(self, app: ASGIApp, audit_utility_getter: Callable[[],
|
|
106
|
+
def __init__(self, app: ASGIApp, audit_utility_getter: Callable[[], AuditStorage | None]) -> None:
|
|
107
107
|
self.audit_utility_getter = audit_utility_getter
|
|
108
108
|
super().__init__(app)
|
|
109
109
|
|
|
@@ -155,17 +155,17 @@ KB_USAGE_STREAM_SUBJECT = "kb-usage.nuclia_db"
|
|
|
155
155
|
|
|
156
156
|
|
|
157
157
|
class StreamAuditStorage(AuditStorage):
|
|
158
|
-
task:
|
|
158
|
+
task: asyncio.Task | None
|
|
159
159
|
initialized: bool
|
|
160
160
|
queue: asyncio.Queue
|
|
161
161
|
|
|
162
162
|
def __init__(
|
|
163
163
|
self,
|
|
164
|
-
nats_servers:
|
|
164
|
+
nats_servers: list[str],
|
|
165
165
|
nats_target: str,
|
|
166
166
|
partitions: int,
|
|
167
167
|
seed: int,
|
|
168
|
-
nats_creds:
|
|
168
|
+
nats_creds: str | None = None,
|
|
169
169
|
service: str = "nucliadb.audit",
|
|
170
170
|
):
|
|
171
171
|
self.nats_servers = nats_servers
|
|
@@ -186,10 +186,10 @@ class StreamAuditStorage(AuditStorage):
|
|
|
186
186
|
|
|
187
187
|
async def reconnected_cb(self):
|
|
188
188
|
# See who we are connected to on reconnect.
|
|
189
|
-
logger.info("Got reconnected to NATS {
|
|
189
|
+
logger.info(f"Got reconnected to NATS {self.nc.connected_url}")
|
|
190
190
|
|
|
191
191
|
async def error_cb(self, e):
|
|
192
|
-
logger.error("There was an error connecting to NATS audit: {}"
|
|
192
|
+
logger.error(f"There was an error connecting to NATS audit: {e}", exc_info=True)
|
|
193
193
|
|
|
194
194
|
async def closed_cb(self):
|
|
195
195
|
logger.info("Connection is closed on NATS")
|
|
@@ -269,12 +269,12 @@ class StreamAuditStorage(AuditStorage):
|
|
|
269
269
|
*,
|
|
270
270
|
kbid: str,
|
|
271
271
|
audit_type: AuditRequest.AuditType.Value, # type: ignore
|
|
272
|
-
when:
|
|
273
|
-
user:
|
|
274
|
-
origin:
|
|
275
|
-
rid:
|
|
276
|
-
field_metadata:
|
|
277
|
-
audit_fields:
|
|
272
|
+
when: Timestamp | None = None,
|
|
273
|
+
user: str | None = None,
|
|
274
|
+
origin: str | None = None,
|
|
275
|
+
rid: str | None = None,
|
|
276
|
+
field_metadata: list[FieldID] | None = None,
|
|
277
|
+
audit_fields: list[AuditField] | None = None,
|
|
278
278
|
):
|
|
279
279
|
auditrequest = AuditRequest()
|
|
280
280
|
|
|
@@ -369,7 +369,7 @@ class StreamAuditStorage(AuditStorage):
|
|
|
369
369
|
search: SearchRequest,
|
|
370
370
|
timeit: float,
|
|
371
371
|
resources: int,
|
|
372
|
-
retrieval_rephrased_question:
|
|
372
|
+
retrieval_rephrased_question: str | None = None,
|
|
373
373
|
):
|
|
374
374
|
context = get_request_context()
|
|
375
375
|
if context is None:
|
|
@@ -420,19 +420,19 @@ class StreamAuditStorage(AuditStorage):
|
|
|
420
420
|
client_type: int,
|
|
421
421
|
origin: str,
|
|
422
422
|
question: str,
|
|
423
|
-
rephrased_question:
|
|
424
|
-
retrieval_rephrased_question:
|
|
425
|
-
chat_context:
|
|
426
|
-
retrieved_context:
|
|
427
|
-
answer:
|
|
428
|
-
reasoning:
|
|
429
|
-
learning_id:
|
|
423
|
+
rephrased_question: str | None,
|
|
424
|
+
retrieval_rephrased_question: str | None,
|
|
425
|
+
chat_context: list[ChatContext],
|
|
426
|
+
retrieved_context: list[RetrievedContext],
|
|
427
|
+
answer: str | None,
|
|
428
|
+
reasoning: str | None,
|
|
429
|
+
learning_id: str | None,
|
|
430
430
|
status_code: int,
|
|
431
|
-
model:
|
|
432
|
-
rephrase_time:
|
|
433
|
-
generative_answer_time:
|
|
434
|
-
generative_answer_first_chunk_time:
|
|
435
|
-
generative_reasoning_first_chunk_time:
|
|
431
|
+
model: str | None,
|
|
432
|
+
rephrase_time: float | None = None,
|
|
433
|
+
generative_answer_time: float | None = None,
|
|
434
|
+
generative_answer_first_chunk_time: float | None = None,
|
|
435
|
+
generative_reasoning_first_chunk_time: float | None = None,
|
|
436
436
|
):
|
|
437
437
|
rcontext = get_request_context()
|
|
438
438
|
if rcontext is None:
|
|
@@ -482,8 +482,8 @@ class StreamAuditStorage(AuditStorage):
|
|
|
482
482
|
learning_id: str,
|
|
483
483
|
good: bool,
|
|
484
484
|
task: int,
|
|
485
|
-
feedback:
|
|
486
|
-
text_block_id:
|
|
485
|
+
feedback: str | None,
|
|
486
|
+
text_block_id: str | None,
|
|
487
487
|
):
|
|
488
488
|
rcontext = get_request_context()
|
|
489
489
|
if rcontext is None:
|
nucliadb_utils/authentication.py
CHANGED
|
@@ -17,12 +17,10 @@
|
|
|
17
17
|
# You should have received a copy of the GNU Affero General Public License
|
|
18
18
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
19
19
|
|
|
20
|
-
import asyncio
|
|
21
20
|
import functools
|
|
22
21
|
import inspect
|
|
23
22
|
import typing
|
|
24
23
|
from enum import Enum
|
|
25
|
-
from typing import Optional, Tuple
|
|
26
24
|
|
|
27
25
|
from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser
|
|
28
26
|
from starlette.exceptions import HTTPException
|
|
@@ -32,7 +30,7 @@ from starlette.websockets import WebSocket
|
|
|
32
30
|
|
|
33
31
|
|
|
34
32
|
class NucliaUser(BaseUser):
|
|
35
|
-
def __init__(self, username: str, security_groups:
|
|
33
|
+
def __init__(self, username: str, security_groups: list[str] | None = None) -> None:
|
|
36
34
|
self.username = username
|
|
37
35
|
self._security_groups = security_groups
|
|
38
36
|
|
|
@@ -45,7 +43,7 @@ class NucliaUser(BaseUser):
|
|
|
45
43
|
return self.username
|
|
46
44
|
|
|
47
45
|
@property
|
|
48
|
-
def security_groups(self) ->
|
|
46
|
+
def security_groups(self) -> list[str] | None:
|
|
49
47
|
return self._security_groups
|
|
50
48
|
|
|
51
49
|
|
|
@@ -63,7 +61,7 @@ class NucliaCloudAuthenticationBackend(AuthenticationBackend):
|
|
|
63
61
|
self.user_header = user_header
|
|
64
62
|
self.security_groups_header = security_groups_header
|
|
65
63
|
|
|
66
|
-
async def authenticate(self, request: HTTPConnection) ->
|
|
64
|
+
async def authenticate(self, request: HTTPConnection) -> tuple[AuthCredentials, BaseUser] | None:
|
|
67
65
|
if self.roles_header not in request.headers:
|
|
68
66
|
return None
|
|
69
67
|
else:
|
|
@@ -74,9 +72,9 @@ class NucliaCloudAuthenticationBackend(AuthenticationBackend):
|
|
|
74
72
|
if self.user_header in request.headers:
|
|
75
73
|
user = request.headers[self.user_header]
|
|
76
74
|
|
|
77
|
-
raw_security_groups:
|
|
75
|
+
raw_security_groups: str | None = request.headers.get(self.security_groups_header)
|
|
78
76
|
|
|
79
|
-
security_groups:
|
|
77
|
+
security_groups: list[str] | None = None
|
|
80
78
|
if raw_security_groups is not None:
|
|
81
79
|
security_groups = raw_security_groups.split(";")
|
|
82
80
|
|
|
@@ -99,9 +97,9 @@ def has_required_scope(conn: HTTPConnection, scopes: typing.Sequence[str]) -> bo
|
|
|
99
97
|
|
|
100
98
|
|
|
101
99
|
def requires(
|
|
102
|
-
scopes:
|
|
100
|
+
scopes: str | typing.Sequence[str],
|
|
103
101
|
status_code: int = 403,
|
|
104
|
-
redirect:
|
|
102
|
+
redirect: str | None = None,
|
|
105
103
|
) -> typing.Callable:
|
|
106
104
|
# As a fastapi requirement, custom Enum classes have to inherit also from
|
|
107
105
|
# string, so we MUST check for Enum before str
|
|
@@ -137,7 +135,7 @@ def requires(
|
|
|
137
135
|
|
|
138
136
|
return websocket_wrapper
|
|
139
137
|
|
|
140
|
-
elif
|
|
138
|
+
elif inspect.iscoroutinefunction(func):
|
|
141
139
|
# Handle async request/response functions.
|
|
142
140
|
@functools.wraps(func)
|
|
143
141
|
async def async_wrapper(*args: typing.Any, **kwargs: typing.Any) -> Response:
|
nucliadb_utils/cache/nats.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
# -*- coding: utf-8 -*-
|
|
2
1
|
# Copyright (C) 2021 Bosutech XXI S.L.
|
|
3
2
|
#
|
|
4
3
|
# nucliadb is offered under the AGPL v3.0 and as commercial software.
|
|
@@ -23,7 +22,6 @@ import functools
|
|
|
23
22
|
import os
|
|
24
23
|
import uuid
|
|
25
24
|
from inspect import iscoroutinefunction
|
|
26
|
-
from typing import Dict, List, Optional, Union
|
|
27
25
|
|
|
28
26
|
import nats
|
|
29
27
|
import nats.errors
|
|
@@ -54,16 +52,16 @@ async def wait_for_it(future: asyncio.Future, msg):
|
|
|
54
52
|
class NatsPubsub(PubSubDriver[Msg]):
|
|
55
53
|
_jetstream = None
|
|
56
54
|
_jsm = None
|
|
57
|
-
_subscriptions:
|
|
55
|
+
_subscriptions: dict[str, Subscription]
|
|
58
56
|
async_callback = True
|
|
59
57
|
|
|
60
58
|
def __init__(
|
|
61
59
|
self,
|
|
62
60
|
thread: bool = False,
|
|
63
|
-
name:
|
|
61
|
+
name: str | None = "natsutility",
|
|
64
62
|
timeout: float = 2.0,
|
|
65
|
-
hosts:
|
|
66
|
-
user_credentials_file:
|
|
63
|
+
hosts: list[str] | None = None,
|
|
64
|
+
user_credentials_file: str | None = None,
|
|
67
65
|
):
|
|
68
66
|
self._hosts = hosts or []
|
|
69
67
|
self._timeout = timeout
|
|
@@ -73,11 +71,11 @@ class NatsPubsub(PubSubDriver[Msg]):
|
|
|
73
71
|
self._uuid = os.environ.get("HOSTNAME", uuid.uuid4().hex)
|
|
74
72
|
self.initialized = False
|
|
75
73
|
self.lock = asyncio.Lock()
|
|
76
|
-
self.nc:
|
|
74
|
+
self.nc: Client | NatsClientTelemetry | None = None
|
|
77
75
|
self.user_credentials_file = user_credentials_file
|
|
78
76
|
|
|
79
77
|
@property
|
|
80
|
-
def jetstream(self) ->
|
|
78
|
+
def jetstream(self) -> JetStreamContext | JetStreamContextTelemetry:
|
|
81
79
|
if self.nc is None:
|
|
82
80
|
raise AttributeError("NC not initialized")
|
|
83
81
|
if self._jetstream is None:
|
|
@@ -154,10 +152,10 @@ class NatsPubsub(PubSubDriver[Msg]):
|
|
|
154
152
|
|
|
155
153
|
async def reconnected_cb(self):
|
|
156
154
|
# See who we are connected to on reconnect.
|
|
157
|
-
logger.info("Got reconnected NATS to {
|
|
155
|
+
logger.info(f"Got reconnected NATS to {self.nc.connected_url.netloc}")
|
|
158
156
|
|
|
159
157
|
async def error_cb(self, e):
|
|
160
|
-
logger.info("There was an error connecting to NATS {}"
|
|
158
|
+
logger.info(f"There was an error connecting to NATS {e}", exc_info=True)
|
|
161
159
|
|
|
162
160
|
async def closed_cb(self):
|
|
163
161
|
logger.info("Connection is closed to NATS")
|
|
@@ -173,7 +171,7 @@ class NatsPubsub(PubSubDriver[Msg]):
|
|
|
173
171
|
else:
|
|
174
172
|
raise ErrConnectionClosed("Could not subscribe")
|
|
175
173
|
|
|
176
|
-
async def subscribe(self, handler: Callback, key, group="", subscription_id:
|
|
174
|
+
async def subscribe(self, handler: Callback, key, group="", subscription_id: str | None = None):
|
|
177
175
|
if subscription_id is None:
|
|
178
176
|
subscription_id = key
|
|
179
177
|
|
|
@@ -189,7 +187,7 @@ class NatsPubsub(PubSubDriver[Msg]):
|
|
|
189
187
|
else:
|
|
190
188
|
raise ErrConnectionClosed("Could not subscribe")
|
|
191
189
|
|
|
192
|
-
async def unsubscribe(self, key: str, subscription_id:
|
|
190
|
+
async def unsubscribe(self, key: str, subscription_id: str | None = None):
|
|
193
191
|
if subscription_id is None:
|
|
194
192
|
subscription_id = key
|
|
195
193
|
|
nucliadb_utils/cache/pubsub.py
CHANGED
|
@@ -17,7 +17,8 @@
|
|
|
17
17
|
# You should have received a copy of the GNU Affero General Public License
|
|
18
18
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
19
19
|
#
|
|
20
|
-
from
|
|
20
|
+
from collections.abc import Awaitable, Callable
|
|
21
|
+
from typing import Generic, TypeVar
|
|
21
22
|
|
|
22
23
|
T = TypeVar("T")
|
|
23
24
|
|
|
@@ -37,15 +38,15 @@ class PubSubDriver(Generic[T]):
|
|
|
37
38
|
async def publish(self, channel_name: str, data: bytes):
|
|
38
39
|
raise NotImplementedError()
|
|
39
40
|
|
|
40
|
-
async def unsubscribe(self, key: str, subscription_id:
|
|
41
|
+
async def unsubscribe(self, key: str, subscription_id: str | None = None):
|
|
41
42
|
raise NotImplementedError()
|
|
42
43
|
|
|
43
44
|
async def subscribe(
|
|
44
45
|
self,
|
|
45
46
|
handler: Callback,
|
|
46
47
|
key: str,
|
|
47
|
-
group:
|
|
48
|
-
subscription_id:
|
|
48
|
+
group: str | None = None,
|
|
49
|
+
subscription_id: str | None = None,
|
|
49
50
|
):
|
|
50
51
|
raise NotImplementedError()
|
|
51
52
|
|
nucliadb_utils/cache/settings.py
CHANGED
|
@@ -17,14 +17,13 @@
|
|
|
17
17
|
# You should have received a copy of the GNU Affero General Public License
|
|
18
18
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
19
19
|
|
|
20
|
-
from typing import List, Optional
|
|
21
20
|
|
|
22
21
|
from pydantic_settings import BaseSettings
|
|
23
22
|
|
|
24
23
|
|
|
25
24
|
class Settings(BaseSettings):
|
|
26
|
-
cache_pubsub_nats_url:
|
|
27
|
-
cache_pubsub_nats_auth:
|
|
25
|
+
cache_pubsub_nats_url: list[str] = []
|
|
26
|
+
cache_pubsub_nats_auth: str | None = None
|
|
28
27
|
|
|
29
28
|
|
|
30
29
|
settings = Settings()
|
nucliadb_utils/debug.py
CHANGED
|
@@ -54,7 +54,7 @@ def display_top(snapshot, key_type="lineno", limit=10): # pragma: no cover
|
|
|
54
54
|
print(f"Top {limit} lines")
|
|
55
55
|
for index, stat in enumerate(top_stats[:limit], 1):
|
|
56
56
|
frame = stat.traceback[0]
|
|
57
|
-
print("
|
|
57
|
+
print(f"#{index}: {frame.filename}:{frame.lineno}: {stat.size / 1024:.1f} KiB")
|
|
58
58
|
line = linecache.getline(frame.filename, frame.lineno).strip()
|
|
59
59
|
if line:
|
|
60
60
|
print(" %s" % line)
|
|
@@ -62,6 +62,6 @@ def display_top(snapshot, key_type="lineno", limit=10): # pragma: no cover
|
|
|
62
62
|
other = top_stats[limit:]
|
|
63
63
|
if other:
|
|
64
64
|
size = sum(stat.size for stat in other)
|
|
65
|
-
print("
|
|
65
|
+
print(f"{len(other)} other: {size / 1024:.1f} KiB")
|
|
66
66
|
total = sum(stat.size for stat in top_stats)
|
|
67
67
|
print("Total allocated size: %.1f KiB" % (total / 1024))
|
|
@@ -16,14 +16,13 @@
|
|
|
16
16
|
#
|
|
17
17
|
# You should have received a copy of the GNU Affero General Public License
|
|
18
18
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
19
|
-
from typing import Optional
|
|
20
19
|
|
|
21
20
|
from pydantic import Field
|
|
22
21
|
from pydantic_settings import BaseSettings
|
|
23
22
|
|
|
24
23
|
|
|
25
24
|
class EncryptionSettings(BaseSettings):
|
|
26
|
-
encryption_secret_key:
|
|
25
|
+
encryption_secret_key: str | None = Field(
|
|
27
26
|
default=None,
|
|
28
27
|
title="Encryption Secret Key",
|
|
29
28
|
description="""Secret key used for encryption and decryption of sensitive data in the database.
|
|
@@ -17,7 +17,6 @@
|
|
|
17
17
|
# You should have received a copy of the GNU Affero General Public License
|
|
18
18
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
19
19
|
|
|
20
|
-
from typing import Union
|
|
21
20
|
|
|
22
21
|
from fastapi.applications import FastAPI
|
|
23
22
|
from starlette.routing import Mount
|
|
@@ -34,7 +33,7 @@ def format_scopes(scope_list):
|
|
|
34
33
|
return "\n".join(f"- `{scope}`" for scope in scope_list)
|
|
35
34
|
|
|
36
35
|
|
|
37
|
-
def extend_openapi(app:
|
|
36
|
+
def extend_openapi(app: FastAPI | Mount): # pragma: no cover
|
|
38
37
|
for route in app.routes:
|
|
39
38
|
# mypy complains about BaseRoute not having endpoint and
|
|
40
39
|
# description attributes, but routes passed here always have
|