nucliadb-utils 5.0.0.post866__py3-none-any.whl → 5.0.0.post899__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -23,7 +23,6 @@ from google.protobuf.timestamp_pb2 import Timestamp
23
23
 
24
24
  from nucliadb_protos.audit_pb2 import (
25
25
  AuditField,
26
- AuditKBCounter,
27
26
  AuditRequest,
28
27
  ChatContext,
29
28
  )
@@ -32,7 +31,9 @@ from nucliadb_protos.resources_pb2 import FieldID
32
31
 
33
32
 
34
33
  class AuditStorage:
35
- async def report(
34
+ initialized: bool = False
35
+
36
+ def report_and_send(
36
37
  self,
37
38
  *,
38
39
  kbid: str,
@@ -43,15 +44,6 @@ class AuditStorage:
43
44
  rid: Optional[str] = None,
44
45
  field_metadata: Optional[List[FieldID]] = None,
45
46
  audit_fields: Optional[List[AuditField]] = None,
46
- kb_counter: Optional[AuditKBCounter] = None,
47
- ):
48
- raise NotImplementedError
49
-
50
- def report_resources(
51
- self,
52
- *,
53
- kbid: str,
54
- resources: int,
55
47
  ):
56
48
  raise NotImplementedError
57
49
 
@@ -61,45 +53,43 @@ class AuditStorage:
61
53
  async def finalize(self):
62
54
  pass
63
55
 
64
- async def visited(self, kbid: str, uuid: str, user: str, origin: str):
65
- raise NotImplementedError
66
-
67
- async def search(
56
+ def visited(
68
57
  self,
69
58
  kbid: str,
59
+ uuid: str,
70
60
  user: str,
71
- client: int,
72
61
  origin: str,
73
- search: SearchRequest,
74
- timeit: float,
75
- resources: int,
76
62
  ):
77
63
  raise NotImplementedError
78
64
 
79
- async def suggest(
65
+ def send(self, msg: AuditRequest):
66
+ raise NotImplementedError
67
+
68
+ def search(
80
69
  self,
81
70
  kbid: str,
82
71
  user: str,
83
72
  client: int,
84
73
  origin: str,
74
+ search: SearchRequest,
85
75
  timeit: float,
76
+ resources: int,
86
77
  ):
87
78
  raise NotImplementedError
88
79
 
89
- async def chat(
80
+ def chat(
90
81
  self,
91
82
  kbid: str,
92
83
  user: str,
93
84
  client: int,
94
85
  origin: str,
95
- timeit: float,
96
86
  question: str,
97
87
  rephrased_question: Optional[str],
98
88
  context: List[ChatContext],
99
89
  answer: Optional[str],
100
90
  learning_id: str,
91
+ rephrase_time: Optional[float] = None,
92
+ generative_answer_time: Optional[float] = None,
93
+ generative_answer_first_chunk_time: Optional[float] = None,
101
94
  ):
102
95
  raise NotImplementedError
103
-
104
- async def delete_kb(self, kbid):
105
- raise NotImplementedError
@@ -23,7 +23,6 @@ from google.protobuf.timestamp_pb2 import Timestamp
23
23
 
24
24
  from nucliadb_protos.audit_pb2 import (
25
25
  AuditField,
26
- AuditKBCounter,
27
26
  AuditRequest,
28
27
  ChatContext,
29
28
  )
@@ -35,10 +34,13 @@ from nucliadb_utils.audit.audit import AuditStorage
35
34
 
36
35
 
37
36
  class BasicAuditStorage(AuditStorage):
37
+ def __init__(self):
38
+ self.initialized = True
39
+
38
40
  def message_to_str(self, message: BrokerMessage) -> str:
39
41
  return f"{message.type}+{message.multiid}+{message.audit.user}+{message.kbid}+{message.uuid}+{message.audit.when.ToJsonString()}+{message.audit.origin}+{message.audit.source}" # noqa
40
42
 
41
- async def report(
43
+ async def report_and_send(
42
44
  self,
43
45
  *,
44
46
  kbid: str,
@@ -49,22 +51,22 @@ class BasicAuditStorage(AuditStorage):
49
51
  rid: Optional[str] = None,
50
52
  field_metadata: Optional[List[FieldID]] = None,
51
53
  audit_fields: Optional[List[AuditField]] = None,
52
- kb_counter: Optional[AuditKBCounter] = None,
53
54
  ):
54
- logger.debug(f"AUDIT {audit_type} {kbid} {user} {origin} {rid} {audit_fields} {kb_counter}")
55
+ logger.debug(f"AUDIT {audit_type} {kbid} {user} {origin} {rid} {audit_fields}")
55
56
 
56
- def report_resources(
57
+ async def visited(
57
58
  self,
58
- *,
59
59
  kbid: str,
60
- resources: int,
60
+ uuid: str,
61
+ user: str,
62
+ origin: str,
61
63
  ):
62
- logger.debug(f"REPORT RESOURCES {kbid} {resources}")
63
-
64
- async def visited(self, kbid: str, uuid: str, user: str, origin: str):
65
64
  logger.debug(f"VISITED {kbid} {uuid} {user} {origin}")
66
65
 
67
- async def search(
66
+ def send(self, msg: AuditRequest):
67
+ logger.debug(f"sending a {msg.type} queued message")
68
+
69
+ def search(
68
70
  self,
69
71
  kbid: str,
70
72
  user: str,
@@ -76,30 +78,19 @@ class BasicAuditStorage(AuditStorage):
76
78
  ):
77
79
  logger.debug(f"SEARCH {kbid} {user} {origin} ''{search}'' {timeit} {resources}")
78
80
 
79
- async def suggest(
80
- self,
81
- kbid: str,
82
- user: str,
83
- client: int,
84
- origin: str,
85
- timeit: float,
86
- ):
87
- logger.debug(f"SUGGEST {kbid} {user} {origin} {timeit}")
88
-
89
- async def chat(
81
+ def chat(
90
82
  self,
91
83
  kbid: str,
92
84
  user: str,
93
85
  client_type: int,
94
86
  origin: str,
95
- timeit: float,
96
87
  question: str,
97
88
  rephrased_question: Optional[str],
98
89
  context: List[ChatContext],
99
90
  answer: Optional[str],
100
91
  learning_id: str,
92
+ rephrase_time: Optional[float] = None,
93
+ generative_answer_time: Optional[float] = None,
94
+ generative_answer_first_chunk_time: Optional[float] = None,
101
95
  ):
102
- logger.debug(f"CHAT {kbid} {user} {origin} {timeit}")
103
-
104
- async def delete_kb(self, kbid):
105
- logger.debug(f"KB DELETED {kbid}")
96
+ logger.debug(f"CHAT {kbid} {user} {origin}")
@@ -18,47 +18,115 @@
18
18
  # along with this program. If not, see <http://www.gnu.org/licenses/>.
19
19
  #
20
20
  import asyncio
21
- from datetime import datetime
22
- from typing import List, Optional
21
+ import contextvars
22
+ import time
23
+ from datetime import datetime, timezone
24
+ from typing import Callable, List, Optional
23
25
 
24
26
  import backoff
25
27
  import mmh3
26
28
  import nats
29
+ from fastapi import Request
27
30
  from google.protobuf.timestamp_pb2 import Timestamp
28
31
  from opentelemetry.trace import format_trace_id, get_current_span
32
+ from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
33
+ from starlette.responses import Response, StreamingResponse
34
+ from starlette.types import ASGIApp
29
35
 
30
36
  from nucliadb_protos.audit_pb2 import (
31
37
  AuditField,
32
- AuditKBCounter,
33
38
  AuditRequest,
34
39
  ChatContext,
35
- ClientType,
36
40
  )
37
41
  from nucliadb_protos.nodereader_pb2 import SearchRequest
38
42
  from nucliadb_protos.resources_pb2 import FieldID
39
43
  from nucliadb_utils import logger
40
44
  from nucliadb_utils.audit.audit import AuditStorage
41
45
  from nucliadb_utils.nats import get_traced_jetstream
42
- from nucliadb_utils.nuclia_usage.protos.kb_usage_pb2 import (
43
- ClientType as ClientTypeKbUsage,
44
- )
45
- from nucliadb_utils.nuclia_usage.protos.kb_usage_pb2 import (
46
- KBSource,
47
- Search,
48
- SearchType,
49
- Service,
50
- Storage,
51
- )
52
- from nucliadb_utils.nuclia_usage.utils.kb_usage_report import KbUsageReportUtility
53
46
 
54
- KB_USAGE_STREAM_AUDIT = "kb-usage.nuclia_db"
47
+
48
+ class RequestContext:
49
+ def __init__(self):
50
+ self.audit_request: AuditRequest = AuditRequest()
51
+ self.start_time: float = time.monotonic()
52
+
53
+
54
+ request_context_var = contextvars.ContextVar[Optional[RequestContext]]("request_context", default=None)
55
+
56
+
57
+ def get_trace_id() -> str:
58
+ span = get_current_span()
59
+ if span is None:
60
+ return ""
61
+ return format_trace_id(span.get_span_context().trace_id)
62
+
63
+
64
+ def get_request_context() -> Optional[RequestContext]:
65
+ return request_context_var.get()
66
+
67
+
68
+ class AuditMiddleware(BaseHTTPMiddleware):
69
+ def __init__(self, app: ASGIApp, audit_utility_getter: Callable[[], Optional[AuditStorage]]) -> None:
70
+ self.audit_utility_getter = audit_utility_getter
71
+ super().__init__(app)
72
+
73
+ @property
74
+ def audit_utility(self):
75
+ return self.audit_utility_getter()
76
+
77
+ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
78
+ context = RequestContext()
79
+ token = request_context_var.set(context)
80
+ context.audit_request.time.FromDatetime(datetime.now(tz=timezone.utc))
81
+ context.audit_request.trace_id = get_trace_id()
82
+ response = await call_next(request)
83
+
84
+ if isinstance(response, StreamingResponse):
85
+ response = self.wrap_streaming_response(response, context)
86
+ else:
87
+ self.enqueue_pending(context)
88
+
89
+ request_context_var.reset(token)
90
+
91
+ return response
92
+
93
+ def enqueue_pending(self, context: RequestContext):
94
+ if context.audit_request.kbid:
95
+ # an audit request with no kbid makes no sense, so we use a missing kbid as a heuristic
96
+ # to detect that no audit has been set during this request
97
+
98
+ context.audit_request.request_time = time.monotonic() - context.start_time
99
+ if self.audit_utility is not None:
100
+ self.audit_utility.send(context.audit_request)
101
+
102
+ def wrap_streaming_response(
103
+ self, response: StreamingResponse, context: RequestContext
104
+ ) -> StreamingResponse:
105
+ """
106
+ When dealing with streaming responses, AND if we depend on any state that only will be available once
107
+ the request is fully finished, the response we have after the dispatch call_next is not enough, as
108
+ there, no iteration of the streaming response has been done yet.
109
+
110
+ This is why we need to rewrap to be able to do the auditing at the _real_ request end without losing
111
+ any audit bits.
112
+ """
113
+ original_body_iterator = response.body_iterator
114
+
115
+ async def custom_body_iterator():
116
+ try:
117
+ async for chunk in original_body_iterator:
118
+ yield chunk
119
+ finally:
120
+ self.enqueue_pending(context)
121
+
122
+ response.body_iterator = custom_body_iterator()
123
+ return response
55
124
 
56
125
 
57
126
  class StreamAuditStorage(AuditStorage):
58
- task: Optional[asyncio.Task] = None
59
- initialized: bool = False
127
+ task: Optional[asyncio.Task]
128
+ initialized: bool
60
129
  queue: asyncio.Queue
61
- lock: asyncio.Lock
62
130
 
63
131
  def __init__(
64
132
  self,
@@ -74,9 +142,10 @@ class StreamAuditStorage(AuditStorage):
74
142
  self.nats_target = nats_target
75
143
  self.partitions = partitions
76
144
  self.seed = seed
77
- self.lock = asyncio.Lock()
78
145
  self.queue = asyncio.Queue()
79
146
  self.service = service
147
+ self.task = None
148
+ self.initialized = False
80
149
 
81
150
  def get_partition(self, kbid: str):
82
151
  return mmh3.hash(kbid, self.seed, signed=False) % self.partitions
@@ -112,16 +181,9 @@ class StreamAuditStorage(AuditStorage):
112
181
  self.js = get_traced_jetstream(self.nc, self.service)
113
182
  self.task = asyncio.create_task(self.run())
114
183
 
115
- self.kb_usage_utility = KbUsageReportUtility(
116
- nats_stream=self.js, nats_subject=KB_USAGE_STREAM_AUDIT
117
- )
118
- await self.kb_usage_utility.initialize()
119
-
120
184
  self.initialized = True
121
185
 
122
186
  async def finalize(self):
123
- await self.kb_usage_utility.finalize()
124
-
125
187
  if self.task is not None:
126
188
  self.task.cancel()
127
189
  if self.nc:
@@ -144,7 +206,7 @@ class StreamAuditStorage(AuditStorage):
144
206
  if item_dequeued:
145
207
  self.queue.task_done()
146
208
 
147
- async def send(self, message: AuditRequest):
209
+ def send(self, message: AuditRequest):
148
210
  self.queue.put_nowait(message)
149
211
 
150
212
  @backoff.on_exception(backoff.expo, (Exception,), jitter=backoff.random_jitter, max_tries=4)
@@ -163,7 +225,7 @@ class StreamAuditStorage(AuditStorage):
163
225
  )
164
226
  return res.seq
165
227
 
166
- async def report(
228
+ def report_and_send(
167
229
  self,
168
230
  *,
169
231
  kbid: str,
@@ -174,18 +236,19 @@ class StreamAuditStorage(AuditStorage):
174
236
  rid: Optional[str] = None,
175
237
  field_metadata: Optional[List[FieldID]] = None,
176
238
  audit_fields: Optional[List[AuditField]] = None,
177
- kb_counter: Optional[AuditKBCounter] = None,
178
239
  ):
179
- # Reports MODIFIED / DELETED / NEW events
180
240
  auditrequest = AuditRequest()
241
+
242
+ # Reports MODIFIED / DELETED / NEW events
243
+
244
+ auditrequest.trace_id = get_trace_id()
181
245
  auditrequest.kbid = kbid
182
246
  auditrequest.userid = user or ""
183
247
  auditrequest.rid = rid or ""
184
248
  auditrequest.origin = origin or ""
185
249
  auditrequest.type = audit_type
186
- if when is None or when.SerializeToString() == b"":
187
- auditrequest.time.FromDatetime(datetime.now())
188
- else:
250
+ # If defined, `when` needs to overwrite any previously set time
251
+ if not (when is None or when.SerializeToString() == b""):
189
252
  auditrequest.time.CopyFrom(when)
190
253
 
191
254
  auditrequest.field_metadata.extend(field_metadata or [])
@@ -193,66 +256,28 @@ class StreamAuditStorage(AuditStorage):
193
256
  if audit_fields:
194
257
  auditrequest.fields_audit.extend(audit_fields)
195
258
 
196
- if kb_counter:
197
- auditrequest.kb_counter.CopyFrom(kb_counter)
198
-
199
- self.kb_usage_utility.send_kb_usage(
200
- service=Service.NUCLIA_DB,
201
- account_id=None,
202
- kb_id=kbid,
203
- kb_source=KBSource.HOSTED,
204
- storage=Storage(paragraphs=kb_counter.paragraphs, fields=kb_counter.fields),
205
- )
259
+ self.send(auditrequest)
206
260
 
207
- auditrequest.trace_id = get_trace_id()
208
-
209
- await self.send(auditrequest)
210
-
211
- def report_resources(
261
+ def visited(
212
262
  self,
213
- *,
214
263
  kbid: str,
215
- resources: int,
264
+ uuid: str,
265
+ user: str,
266
+ origin: str,
267
+ send: bool = False,
216
268
  ):
217
- self.kb_usage_utility.send_kb_usage(
218
- service=Service.NUCLIA_DB,
219
- account_id=None,
220
- kb_id=kbid,
221
- kb_source=KBSource.HOSTED,
222
- storage=Storage(resources=resources),
223
- )
269
+ context = get_request_context()
270
+ if context is None:
271
+ return
272
+ auditrequest = context.audit_request
224
273
 
225
- async def visited(self, kbid: str, uuid: str, user: str, origin: str):
226
- auditrequest = AuditRequest()
227
274
  auditrequest.origin = origin
228
275
  auditrequest.userid = user
229
276
  auditrequest.rid = uuid
230
277
  auditrequest.kbid = kbid
231
278
  auditrequest.type = AuditRequest.VISITED
232
- auditrequest.time.FromDatetime(datetime.now())
233
-
234
- auditrequest.trace_id = get_trace_id()
235
-
236
- await self.send(auditrequest)
237
279
 
238
- async def delete_kb(self, kbid):
239
- # Search is a base64 encoded search
240
- auditrequest = AuditRequest()
241
- auditrequest.kbid = kbid
242
- auditrequest.type = AuditRequest.KB_DELETED
243
- auditrequest.time.FromDatetime(datetime.now())
244
- auditrequest.trace_id = get_trace_id()
245
- await self.send(auditrequest)
246
-
247
- self.kb_usage_utility.send_kb_usage(
248
- service=Service.NUCLIA_DB,
249
- account_id=None,
250
- kb_id=kbid,
251
- kb_source=KBSource.HOSTED,
252
- storage=Storage(paragraphs=0, fields=0, resources=0),
253
- )
254
-
255
- async def search(
280
+ def search(
256
281
  self,
257
282
  kbid: str,
258
283
  user: str,
@@ -262,95 +287,53 @@ class StreamAuditStorage(AuditStorage):
262
287
  timeit: float,
263
288
  resources: int,
264
289
  ):
265
- # Search is a base64 encoded search
266
- auditrequest = AuditRequest()
290
+ context = get_request_context()
291
+ if context is None:
292
+ return
293
+
294
+ auditrequest = context.audit_request
295
+
267
296
  auditrequest.origin = origin
268
297
  auditrequest.client_type = client_type # type: ignore
269
298
  auditrequest.userid = user
270
299
  auditrequest.kbid = kbid
271
300
  auditrequest.search.CopyFrom(search)
272
- auditrequest.timeit = timeit
301
+ auditrequest.retrieval_time = timeit
273
302
  auditrequest.resources = resources
274
303
  auditrequest.type = AuditRequest.SEARCH
275
- auditrequest.time.FromDatetime(datetime.now())
276
-
277
- auditrequest.trace_id = get_trace_id()
278
- await self.send(auditrequest)
279
-
280
- self.kb_usage_utility.send_kb_usage(
281
- service=Service.NUCLIA_DB,
282
- account_id=None,
283
- kb_id=kbid,
284
- kb_source=KBSource.HOSTED,
285
- # TODO unify AuditRequest client type and Nuclia Usage client type
286
- searches=[
287
- Search(
288
- client=ClientTypeKbUsage.Value(ClientType.Name(client_type)), # type: ignore
289
- type=SearchType.SEARCH,
290
- tokens=2000,
291
- num_searches=1,
292
- )
293
- ],
294
- )
295
304
 
296
- async def suggest(
305
+ def chat(
297
306
  self,
298
307
  kbid: str,
299
308
  user: str,
300
309
  client_type: int,
301
310
  origin: str,
302
- timeit: float,
303
- ):
304
- auditrequest = AuditRequest()
305
- auditrequest.origin = origin
306
- auditrequest.client_type = client_type # type: ignore
307
- auditrequest.userid = user
308
- auditrequest.kbid = kbid
309
- auditrequest.timeit = timeit
310
- auditrequest.type = AuditRequest.SUGGEST
311
- auditrequest.time.FromDatetime(datetime.now())
312
- auditrequest.trace_id = get_trace_id()
313
-
314
- await self.send(auditrequest)
315
-
316
- self.kb_usage_utility.send_kb_usage(
317
- service=Service.NUCLIA_DB,
318
- account_id=None,
319
- kb_id=kbid,
320
- kb_source=KBSource.HOSTED,
321
- # TODO unify AuditRequest client type and Nuclia Usage client type
322
- searches=[
323
- Search(
324
- client=ClientTypeKbUsage.Value(ClientType.Name(client_type)), # type: ignore
325
- type=SearchType.SUGGEST,
326
- tokens=0,
327
- num_searches=1,
328
- )
329
- ],
330
- )
331
-
332
- async def chat(
333
- self,
334
- kbid: str,
335
- user: str,
336
- client_type: int,
337
- origin: str,
338
- timeit: float,
339
311
  question: str,
340
312
  rephrased_question: Optional[str],
341
313
  context: List[ChatContext],
342
314
  answer: Optional[str],
343
315
  learning_id: str,
316
+ rephrase_time: Optional[float] = None,
317
+ generative_answer_time: Optional[float] = None,
318
+ generative_answer_first_chunk_time: Optional[float] = None,
344
319
  ):
345
- auditrequest = AuditRequest()
320
+ rcontext = get_request_context()
321
+ if rcontext is None:
322
+ return
323
+
324
+ auditrequest = rcontext.audit_request
325
+
346
326
  auditrequest.origin = origin
347
327
  auditrequest.client_type = client_type # type: ignore
348
328
  auditrequest.userid = user
349
329
  auditrequest.kbid = kbid
350
- auditrequest.timeit = timeit
330
+ if rephrase_time is not None:
331
+ auditrequest.rephrase_time = rephrase_time
332
+ if generative_answer_time is not None:
333
+ auditrequest.generative_answer_time = generative_answer_time
334
+ if generative_answer_first_chunk_time is not None:
335
+ auditrequest.generative_answer_first_chunk_time = generative_answer_first_chunk_time
351
336
  auditrequest.type = AuditRequest.CHAT
352
- auditrequest.time.FromDatetime(datetime.now())
353
- auditrequest.trace_id = get_trace_id()
354
337
  auditrequest.chat.question = question
355
338
  auditrequest.chat.context.extend(context)
356
339
  auditrequest.chat.learning_id = learning_id
@@ -358,11 +341,3 @@ class StreamAuditStorage(AuditStorage):
358
341
  auditrequest.chat.rephrased_question = rephrased_question
359
342
  if answer is not None:
360
343
  auditrequest.chat.answer = answer
361
- await self.send(auditrequest)
362
-
363
-
364
- def get_trace_id() -> str:
365
- span = get_current_span()
366
- if span is None:
367
- return ""
368
- return format_trace_id(span.get_span_context().trace_id)
@@ -22,10 +22,9 @@ import logging
22
22
  from collections.abc import Iterable
23
23
  from contextlib import suppress
24
24
  from datetime import datetime, timezone
25
- from typing import Optional
26
-
27
- from nats.js.client import JetStreamContext
25
+ from typing import List, Optional
28
26
 
27
+ from nucliadb_utils.nats import NatsConnectionManager
29
28
  from nucliadb_utils.nuclia_usage.protos.kb_usage_pb2 import (
30
29
  KBSource,
31
30
  KbUsage,
@@ -40,29 +39,48 @@ logger = logging.getLogger(__name__)
40
39
 
41
40
 
42
41
  class KbUsageReportUtility:
42
+ task: Optional[asyncio.Task]
43
+ initialized: bool
43
44
  queue: asyncio.Queue
44
- lock: asyncio.Lock
45
+ service: str
45
46
 
46
47
  def __init__(
47
48
  self,
48
- nats_stream: JetStreamContext,
49
49
  nats_subject: str,
50
+ nats_servers: List[str],
51
+ nats_creds: Optional[str] = None,
50
52
  max_queue_size: int = 100,
53
+ service: str = "",
51
54
  ):
52
- self.nats_stream = nats_stream
55
+ self.nats_connection_manager = NatsConnectionManager(
56
+ service_name=service,
57
+ nats_servers=nats_servers,
58
+ nats_creds=nats_creds,
59
+ )
53
60
  self.nats_subject = nats_subject
54
61
  self.queue = asyncio.Queue(max_queue_size)
55
62
  self.task = None
63
+ self.initialized = False
56
64
 
57
65
  async def initialize(self):
58
- if self.task is None:
59
- self.task = asyncio.create_task(self.run())
66
+ if not self.initialized and self.nats_connection_manager._nats_servers:
67
+ await self.nats_connection_manager.initialize()
68
+
69
+ if self.task is None:
70
+ self.task = asyncio.create_task(self.run())
71
+
72
+ self.initialized = True
60
73
 
61
74
  async def finalize(self):
62
- if self.task is not None:
63
- self.task.cancel()
64
- with suppress(asyncio.CancelledError, asyncio.exceptions.TimeoutError):
65
- await asyncio.wait_for(self.task, timeout=2)
75
+ if self.initialized:
76
+ if self.task is not None:
77
+ self.task.cancel()
78
+ with suppress(asyncio.CancelledError, asyncio.exceptions.TimeoutError):
79
+ await asyncio.wait_for(self.task, timeout=2)
80
+
81
+ await self.nats_connection_manager.finalize()
82
+
83
+ self.initialized = False
66
84
 
67
85
  async def run(self) -> None:
68
86
  while True:
@@ -75,13 +93,15 @@ class KbUsageReportUtility:
75
93
  self.queue.task_done()
76
94
 
77
95
  def send(self, message: KbUsage):
96
+ if not self.initialized:
97
+ return
78
98
  try:
79
99
  self.queue.put_nowait(message)
80
100
  except asyncio.QueueFull:
81
101
  logger.warning("KbUsage utility queue is full, dropping message")
82
102
 
83
103
  async def _send(self, message: KbUsage) -> int:
84
- res = await self.nats_stream.publish(
104
+ res = await self.nats_connection_manager.js.publish(
85
105
  self.nats_subject,
86
106
  message.SerializeToString(),
87
107
  )
@@ -200,6 +200,16 @@ class AuditSettings(BaseSettings):
200
200
  audit_settings = AuditSettings()
201
201
 
202
202
 
203
+ class UsageSettings(BaseSettings):
204
+ usage_jetstream_subject: Optional[str] = "kb-usage.nuclia_db"
205
+ usage_jetstream_servers: List[str] = []
206
+ usage_jetstream_auth: Optional[str] = None
207
+ usage_stream: str = "kb-usage"
208
+
209
+
210
+ usage_settings = UsageSettings()
211
+
212
+
203
213
  class NATSConsumerSettings(BaseSettings):
204
214
  # Read about message ordering:
205
215
  # https://docs.nats.io/nats-concepts/subject_mapping#when-is-deterministic-partitioning-needed
@@ -191,7 +191,11 @@ def start_gnatsd(gnatsd: Gnatsd): # pragma: no cover
191
191
 
192
192
  @pytest.fixture(scope="session")
193
193
  def natsd_server(): # pragma: no cover
194
- if not os.path.isfile("nats-server"):
194
+ # Create a persistent temporary directory
195
+ tmpdir = tempfile.mkdtemp()
196
+ nats_server_path = os.path.join(tmpdir, "nats-server")
197
+
198
+ if not os.path.isfile(nats_server_path):
195
199
  version = "v2.10.12"
196
200
  arch = platform.machine()
197
201
  if arch == "x86_64":
@@ -205,13 +209,13 @@ def natsd_server(): # pragma: no cover
205
209
 
206
210
  file = zipfile.open(f"nats-server-{version}-{system}-{arch}/nats-server")
207
211
  content = file.read()
208
- with open("nats-server", "wb") as f:
212
+ with open(nats_server_path, "wb") as f:
209
213
  f.write(content)
210
- os.chmod("nats-server", 755)
214
+ os.chmod(nats_server_path, 0o755)
211
215
 
212
216
  server = Gnatsd(port=4222)
213
217
  server.bin_name = "nats-server"
214
- server.path = os.getcwd()
218
+ server.path = tmpdir
215
219
  return server
216
220
 
217
221
 
@@ -40,6 +40,7 @@ from nucliadb_utils.encryption.settings import settings as encryption_settings
40
40
  from nucliadb_utils.exceptions import ConfigurationError
41
41
  from nucliadb_utils.indexing import IndexingUtility
42
42
  from nucliadb_utils.nats import NatsConnectionManager
43
+ from nucliadb_utils.nuclia_usage.utils.kb_usage_report import KbUsageReportUtility
43
44
  from nucliadb_utils.partition import PartitionUtility
44
45
  from nucliadb_utils.settings import (
45
46
  FileBackendConfig,
@@ -48,6 +49,7 @@ from nucliadb_utils.settings import (
48
49
  nuclia_settings,
49
50
  storage_settings,
50
51
  transaction_settings,
52
+ usage_settings,
51
53
  )
52
54
  from nucliadb_utils.storages.settings import settings as extended_storage_settings
53
55
  from nucliadb_utils.store import MAIN
@@ -81,6 +83,7 @@ class Utility(str, Enum):
81
83
  LOCAL_STORAGE = "local_storage"
82
84
  NUCLIA_STORAGE = "nuclia_storage"
83
85
  MAINDB_DRIVER = "driver"
86
+ USAGE = "usage"
84
87
  ENDECRYPTOR = "endecryptor"
85
88
  PINECONE_SESSION = "pinecone_session"
86
89
 
@@ -322,16 +325,41 @@ def get_audit() -> Optional[AuditStorage]:
322
325
  return get_utility(Utility.AUDIT)
323
326
 
324
327
 
325
- async def start_audit_utility(service: str):
326
- audit_utility: Optional[AuditStorage] = get_utility(Utility.AUDIT)
327
- if audit_utility is not None:
328
+ def get_usage_utility() -> Optional[KbUsageReportUtility]:
329
+ return get_utility(Utility.USAGE)
330
+
331
+
332
+ async def start_usage_utility(service: str):
333
+ usage_utility: Optional[KbUsageReportUtility] = get_utility(Utility.USAGE)
334
+ if usage_utility is not None:
328
335
  return
329
336
 
337
+ usage_utility = KbUsageReportUtility(
338
+ nats_subject=cast(str, usage_settings.usage_jetstream_subject),
339
+ nats_servers=usage_settings.usage_jetstream_servers,
340
+ nats_creds=usage_settings.usage_jetstream_auth,
341
+ service=service,
342
+ )
343
+ logger.info(f"Configuring usage report utility {usage_settings.usage_jetstream_subject}")
344
+ await usage_utility.initialize()
345
+ set_utility(Utility.USAGE, usage_utility)
346
+
347
+
348
+ async def stop_usage_utility():
349
+ usage_utility = get_usage_utility()
350
+ if usage_utility:
351
+ await usage_utility.finalize()
352
+ clean_utility(Utility.USAGE)
353
+
354
+
355
+ def register_audit_utility(service: str) -> AuditStorage:
330
356
  if audit_settings.audit_driver == "basic":
331
- audit_utility = BasicAuditStorage()
357
+ b_audit_utility: AuditStorage = BasicAuditStorage()
358
+ set_utility(Utility.AUDIT, b_audit_utility)
332
359
  logger.info("Configuring basic audit log")
360
+ return b_audit_utility
333
361
  elif audit_settings.audit_driver == "stream":
334
- audit_utility = StreamAuditStorage(
362
+ s_audit_utility: AuditStorage = StreamAuditStorage(
335
363
  nats_creds=audit_settings.audit_jetstream_auth,
336
364
  nats_servers=audit_settings.audit_jetstream_servers,
337
365
  nats_target=cast(str, audit_settings.audit_jetstream_target),
@@ -339,11 +367,22 @@ async def start_audit_utility(service: str):
339
367
  seed=audit_settings.audit_hash_seed,
340
368
  service=service,
341
369
  )
370
+ set_utility(Utility.AUDIT, s_audit_utility)
342
371
  logger.info(f"Configuring stream audit log {audit_settings.audit_jetstream_target}")
372
+ return s_audit_utility
343
373
  else:
344
374
  raise ConfigurationError("Invalid audit driver")
345
- await audit_utility.initialize()
346
- set_utility(Utility.AUDIT, audit_utility)
375
+
376
+
377
+ async def start_audit_utility(service: str):
378
+ audit_utility: Optional[AuditStorage] = get_utility(Utility.AUDIT)
379
+ if audit_utility is not None and audit_utility.initialized is True:
380
+ return
381
+
382
+ if audit_utility is None:
383
+ audit_utility = register_audit_utility(service)
384
+ if audit_utility.initialized is False:
385
+ await audit_utility.initialize()
347
386
 
348
387
 
349
388
  async def stop_audit_utility():
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: nucliadb_utils
3
- Version: 5.0.0.post866
3
+ Version: 5.0.0.post899
4
4
  Home-page: https://nuclia.com
5
5
  License: BSD
6
6
  Classifier: Development Status :: 4 - Beta
@@ -23,8 +23,8 @@ Requires-Dist: PyNaCl
23
23
  Requires-Dist: pyjwt >=2.4.0
24
24
  Requires-Dist: memorylru >=1.1.2
25
25
  Requires-Dist: mrflagly
26
- Requires-Dist: nucliadb-protos >=5.0.0.post866
27
- Requires-Dist: nucliadb-telemetry >=5.0.0.post866
26
+ Requires-Dist: nucliadb-protos >=5.0.0.post899
27
+ Requires-Dist: nucliadb-telemetry >=5.0.0.post899
28
28
  Provides-Extra: cache
29
29
  Requires-Dist: redis >=4.3.4 ; extra == 'cache'
30
30
  Requires-Dist: orjson >=3.6.7 ; extra == 'cache'
@@ -12,19 +12,19 @@ nucliadb_utils/nats.py,sha256=zTAXECDXeCPtydk3F_6EMFDZ059kK0UYUU_tnWoxgXs,8208
12
12
  nucliadb_utils/partition.py,sha256=jBgy4Hu5Iwn4gjbPPcthSykwf-qNx-GcLAIwbzPd1d0,1157
13
13
  nucliadb_utils/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
14
  nucliadb_utils/run.py,sha256=HpAIM8xbR7UpVC2_7xOjB4fYbUVykyPP6yHrv2RD3DI,1707
15
- nucliadb_utils/settings.py,sha256=AaOtQZVRqRcMnUyN1l1MpR10lANaDT2uPrbhmTyn6uk,7647
15
+ nucliadb_utils/settings.py,sha256=QR51SX0T17-_YofMNpci-nkI77l_CUWFy4H7i8hNOHU,7911
16
16
  nucliadb_utils/signals.py,sha256=JRNv2y9zLtBjOANBf7krGfDGfOc9qcoXZ6N1nKWS2FE,2674
17
17
  nucliadb_utils/store.py,sha256=kQ35HemE0v4_Qg6xVqNIJi8vSFAYQtwI3rDtMsNy62Y,890
18
18
  nucliadb_utils/transaction.py,sha256=mwcI3aIHAvU5KOGqd_Uz_d1XQzXhk_-NWY8NqU1lfb0,7307
19
- nucliadb_utils/utilities.py,sha256=0NNoWLAR88Kb6QuHOK08ZIhrjAKsQ46iQqwN3-I9S6w,15355
19
+ nucliadb_utils/utilities.py,sha256=jjapoJvgZ-H0uMoThVK8P6EvCTTwL0kM_EEJY-Q22EU,16747
20
20
  nucliadb_utils/aiopynecone/__init__.py,sha256=cp15ZcFnHvpcu_5-aK2A4uUyvuZVV_MJn4bIXMa20ks,835
21
21
  nucliadb_utils/aiopynecone/client.py,sha256=jlTLOdphrLZElKW9ajVzDcebN7CSAbjNyX6tnD37eqY,18721
22
22
  nucliadb_utils/aiopynecone/exceptions.py,sha256=hFhq-UEY4slqNWjObXr_LPnRf_AQ1vpcG4SF2XRFd1E,2873
23
23
  nucliadb_utils/aiopynecone/models.py,sha256=B_ihJhHZGp3ivQVUxhV49uoUnHe1PLDKxTgHNbHgSS0,2937
24
24
  nucliadb_utils/audit/__init__.py,sha256=cp15ZcFnHvpcu_5-aK2A4uUyvuZVV_MJn4bIXMa20ks,835
25
- nucliadb_utils/audit/audit.py,sha256=dn5ZnCVQUlCcvdjzaORghbrjk9QgVGrtkfIftq30Bp8,2819
26
- nucliadb_utils/audit/basic.py,sha256=NViey6mKbCXqRTLDBX2xNTcCg9I-2e4oB2xkekuhDvM,3392
27
- nucliadb_utils/audit/stream.py,sha256=CSNFrwCT_IeHFStx4q4ImgOQ7v0WzJr4YG6QRzE7YlQ,11952
25
+ nucliadb_utils/audit/audit.py,sha256=kpjiTLgWtLXA5DNgJcmeIaodel11Mz0Y-JfQt-yfhis,2641
26
+ nucliadb_utils/audit/basic.py,sha256=LwCCvUAwy8LNzYdrw3NebxqNdUumCEyRQ244dJ93JfE,3190
27
+ nucliadb_utils/audit/stream.py,sha256=_RZf1G9_c5eaxfutq1OdMmnRaRM7EiaXH33Kb4mN9Gw,11525
28
28
  nucliadb_utils/cache/__init__.py,sha256=itSI7dtTwFP55YMX4iK7JzdMHS5CQVUiB1XzQu4UBh8,833
29
29
  nucliadb_utils/cache/exceptions.py,sha256=Zu-O_-0-yctOEgoDGI92gPzWfBMRrpiAyESA62ld6MA,975
30
30
  nucliadb_utils/cache/nats.py,sha256=-AjCfkFgKVdJUlGR0hT9JDSNkPVFg4S6w9eW-ZIcXPM,7037
@@ -43,7 +43,7 @@ nucliadb_utils/nuclia_usage/protos/kb_usage_pb2.pyi,sha256=xhyc3jJBh0KZuWcgmIbwS
43
43
  nucliadb_utils/nuclia_usage/protos/kb_usage_pb2_grpc.py,sha256=dhop8WwjplPfORYPYb9HtcS9gHMzqXPJQGqXYRjV-6M,1008
44
44
  nucliadb_utils/nuclia_usage/protos/kb_usage_pb2_grpc.pyi,sha256=6RIsZ2934iodEckflpBStgLKEkFhKfNmZ72UKg2Bwb4,911
45
45
  nucliadb_utils/nuclia_usage/utils/__init__.py,sha256=cp15ZcFnHvpcu_5-aK2A4uUyvuZVV_MJn4bIXMa20ks,835
46
- nucliadb_utils/nuclia_usage/utils/kb_usage_report.py,sha256=E1eUSFXBVNzQP9Q2rWj9y3koCO5S7iKwckny_AoLKuk,3870
46
+ nucliadb_utils/nuclia_usage/utils/kb_usage_report.py,sha256=t4-QXUTOTeMBWFFO8hdqiIco1fgC2Jla4NIWuRSs6N4,4566
47
47
  nucliadb_utils/storages/__init__.py,sha256=5Qc8AUWiJv9_JbGCBpAn88AIJhwDlm0OPQpg2ZdRL4U,872
48
48
  nucliadb_utils/storages/azure.py,sha256=egMDwLNIGSQyVevuySt2AswzFdNAcih05BbRg3-p8IU,16015
49
49
  nucliadb_utils/storages/exceptions.py,sha256=mm_wX4YRtp7u7enkk_4pMSlX5AQQuFbq4xLmupVDt3Y,2502
@@ -62,10 +62,10 @@ nucliadb_utils/tests/fixtures.py,sha256=rmix1VGpPULBxJd_5ScgVD4Z0nRrkNOVt-uTW_Wo
62
62
  nucliadb_utils/tests/gcs.py,sha256=Ii8BCHUAAxFIzX67pKTRFRgbqv3FJ6DrPAdAx2Xod1Y,3036
63
63
  nucliadb_utils/tests/indexing.py,sha256=YW2QhkhO9Q_8A4kKWJaWSvXvyQ_AiAwY1VylcfVQFxk,1513
64
64
  nucliadb_utils/tests/local.py,sha256=c3gZJJWmvOftruJkIQIwB3q_hh3uxEhqGIAVWim1Bbk,1343
65
- nucliadb_utils/tests/nats.py,sha256=Tosonm9A9cusImyji80G4pgdXEHNVPaCLT5TbFK_ra0,7543
65
+ nucliadb_utils/tests/nats.py,sha256=xqpww4jZjTKY9oPGlJdDJG67L3FIBQsa9qDHxILR8r8,7687
66
66
  nucliadb_utils/tests/s3.py,sha256=YB8QqDaBXxyhHonEHmeBbRRDmvB7sTOaKBSi8KBGokg,2330
67
- nucliadb_utils-5.0.0.post866.dist-info/METADATA,sha256=k8REPv3DfxUYJjoLcSlAWEr9NmUJNx2rPpPW2owrHeQ,2073
68
- nucliadb_utils-5.0.0.post866.dist-info/WHEEL,sha256=-oYQCr74JF3a37z2nRlQays_SX2MqOANoqVjBBAP2yE,91
69
- nucliadb_utils-5.0.0.post866.dist-info/top_level.txt,sha256=fE3vJtALTfgh7bcAWcNhcfXkNPp_eVVpbKK-2IYua3E,15
70
- nucliadb_utils-5.0.0.post866.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
71
- nucliadb_utils-5.0.0.post866.dist-info/RECORD,,
67
+ nucliadb_utils-5.0.0.post899.dist-info/METADATA,sha256=--jd5upWgPLSU_qpT3aCoyyMFyHhxq5vkrtrGdXGXSM,2073
68
+ nucliadb_utils-5.0.0.post899.dist-info/WHEEL,sha256=Wyh-_nZ0DJYolHNn1_hMa4lM7uDedD_RGVwbmTjyItk,91
69
+ nucliadb_utils-5.0.0.post899.dist-info/top_level.txt,sha256=fE3vJtALTfgh7bcAWcNhcfXkNPp_eVVpbKK-2IYua3E,15
70
+ nucliadb_utils-5.0.0.post899.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
71
+ nucliadb_utils-5.0.0.post899.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (71.0.3)
2
+ Generator: setuptools (71.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5