nucliadb 6.3.1.post3524__py3-none-any.whl → 6.3.1.post3531__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,249 @@
+ # Copyright (C) 2021 Bosutech XXI S.L.
+ #
+ # nucliadb is offered under the AGPL v3.0 and as commercial software.
+ # For commercial licensing, contact us at info@nuclia.com.
+ #
+ # AGPL:
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU Affero General Public License as
+ # published by the Free Software Foundation, either version 3 of the
+ # License, or (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU Affero General Public License for more details.
+ #
+ # You should have received a copy of the GNU Affero General Public License
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
+ #
+
+
+ import asyncio
+ import functools
+ import tarfile
+ from typing import AsyncIterator, Callable, Optional, Union
+
+ from nucliadb.backups.const import MaindbKeys, StorageKeys
+ from nucliadb.backups.models import RestoreBackupRequest
+ from nucliadb.backups.settings import settings
+ from nucliadb.common.context import ApplicationContext
+ from nucliadb.export_import.utils import (
+     import_binary,
+     import_broker_message,
+     set_entities_groups,
+     set_labels,
+ )
+ from nucliadb.tasks.retries import TaskRetryHandler
+ from nucliadb_protos import knowledgebox_pb2 as kb_pb2
+ from nucliadb_protos.resources_pb2 import CloudFile
+ from nucliadb_protos.writer_pb2 import BrokerMessage
+
+
+ async def restore_kb_retried(context: ApplicationContext, msg: RestoreBackupRequest):
+     kbid = msg.kbid
+     backup_id = msg.backup_id
+
+     retry_handler = TaskRetryHandler(
+         kbid=kbid,
+         task_type="restore",
+         task_id=backup_id,
+         context=context,
+         max_retries=3,
+     )
+
+     @retry_handler.wrap
+     async def _restore_kb(context: ApplicationContext, kbid: str, backup_id: str):
+         await restore_kb(context, kbid, backup_id)
+
+     await _restore_kb(context, kbid, backup_id)
+
+
+ async def restore_kb(context: ApplicationContext, kbid: str, backup_id: str):
+     """
+     Downloads the backup files from the cloud storage and imports them into the KB.
+     """
+     await restore_resources(context, kbid, backup_id)
+     await restore_labels(context, kbid, backup_id)
+     await restore_entities(context, kbid, backup_id)
+     await delete_last_restored_resource_key(context, kbid, backup_id)
+
+
+ async def restore_resources(context: ApplicationContext, kbid: str, backup_id: str):
+     last_restored = await get_last_restored_resource_key(context, kbid, backup_id)
+     tasks = []
+     async for object_info in context.blob_storage.iterate_objects(
+         bucket=settings.backups_bucket,
+         prefix=StorageKeys.RESOURCES_PREFIX.format(kbid=kbid, backup_id=backup_id),
+         start=last_restored,
+     ):
+         key = object_info.name
+         resource_id = key.split("/")[-1].rstrip(".tar")
+         tasks.append(asyncio.create_task(restore_resource(context, kbid, backup_id, resource_id)))
+         if len(tasks) > settings.restore_resources_concurrency:
+             await asyncio.gather(*tasks)
+             tasks = []
+             await set_last_restored_resource_key(context, kbid, backup_id, key)
+     if len(tasks) > 0:
+         await asyncio.gather(*tasks)
+         tasks = []
+         await set_last_restored_resource_key(context, kbid, backup_id, key)
+
+
+ async def get_last_restored_resource_key(
+     context: ApplicationContext, kbid: str, backup_id: str
+ ) -> Optional[str]:
+     key = MaindbKeys.LAST_RESTORED.format(kbid=kbid, backup_id=backup_id)
+     async with context.kv_driver.transaction(read_only=True) as txn:
+         raw = await txn.get(key)
+         if raw is None:
+             return None
+         return raw.decode()
+
+
+ async def set_last_restored_resource_key(
+     context: ApplicationContext, kbid: str, backup_id: str, resource_id: str
+ ):
+     key = MaindbKeys.LAST_RESTORED.format(kbid=kbid, backup_id=backup_id)
+     async with context.kv_driver.transaction() as txn:
+         await txn.set(key, resource_id.encode())
+         await txn.commit()
+
+
+ async def delete_last_restored_resource_key(context: ApplicationContext, kbid: str, backup_id: str):
+     key = MaindbKeys.LAST_RESTORED.format(kbid=kbid, backup_id=backup_id)
+     async with context.kv_driver.transaction() as txn:
+         await txn.delete(key)
+         await txn.commit()
+
+
+ class CloudFileBinary:
+     def __init__(self, uri: str, download_stream: Callable[[int], AsyncIterator[bytes]]):
+         self.uri = uri
+         self.download_stream = download_stream
+
+     async def read(self, chunk_size: int) -> AsyncIterator[bytes]:
+         async for chunk in self.download_stream(chunk_size):
+             yield chunk
+
+
+ class ResourceBackupReader:
+     def __init__(self, download_stream: AsyncIterator[bytes]):
+         self.download_stream = download_stream
+         self.buffer = b""
+
+     async def read(self, size: int) -> bytes:
+         while len(self.buffer) < size:
+             chunk = await self.download_stream.__anext__()
+             self.buffer += chunk
+         result = self.buffer[:size]
+         self.buffer = self.buffer[size:]
+         return result
+
+     async def iter_data(self, total_bytes: int, chunk_size: int = 1024 * 1024) -> AsyncIterator[bytes]:
+         padding_bytes = 0
+         if total_bytes % 512 != 0:
+             # We need to read the padding bytes and then discard them
+             padding_bytes = 512 - (total_bytes % 512)
+         read_bytes = 0
+         padding_reached = False
+         async for chunk in self._iter(total_bytes + padding_bytes, chunk_size):
+             if padding_reached:
+                 # Skip padding bytes. We can't break here because we need
+                 # to read the padding bytes from the stream
+                 continue
+             padding_reached = read_bytes + len(chunk) >= total_bytes
+             if padding_reached:
+                 chunk = chunk[: total_bytes - read_bytes]
+             else:
+                 read_bytes += len(chunk)
+             yield chunk
+
+     async def _iter(self, total_bytes: int, chunk_size: int = 1024 * 1024) -> AsyncIterator[bytes]:
+         remaining_bytes = total_bytes
+         while remaining_bytes > 0:
+             to_read = min(chunk_size, remaining_bytes)
+             chunk = await self.read(to_read)
+             yield chunk
+             remaining_bytes -= len(chunk)
+         assert remaining_bytes == 0
+
+     async def read_tarinfo(self):
+         raw_tar_header = await self.read(512)
+         return tarfile.TarInfo.frombuf(raw_tar_header, encoding="utf-8", errors="strict")
+
+     async def read_data(self, tarinfo: tarfile.TarInfo) -> bytes:
+         tarinfo_size = tarinfo.size
+         padding_bytes = 0
+         if tarinfo_size % 512 != 0:
+             # We need to read the padding bytes and then discard them
+             padding_bytes = 512 - (tarinfo_size % 512)
+         data = await self.read(tarinfo_size + padding_bytes)
+         return data[:tarinfo_size]
+
+     async def read_item(self) -> Union[BrokerMessage, CloudFile, CloudFileBinary]:
+         tarinfo = await self.read_tarinfo()
+         if tarinfo.name.startswith("broker-message"):
+             raw_bm = await self.read_data(tarinfo)
+             bm = BrokerMessage()
+             bm.ParseFromString(raw_bm)
+             return bm
+         elif tarinfo.name.startswith("cloud-files"):
+             raw_cf = await self.read_data(tarinfo)
+             cf = CloudFile()
+             cf.FromString(raw_cf)
+             return cf
+         elif tarinfo.name.startswith("binaries"):
+             uri = tarinfo.name.lstrip("binaries/")
+             size = tarinfo.size
+             download_stream = functools.partial(self.iter_data, size)
+             return CloudFileBinary(uri, download_stream)
+         else: # pragma: no cover
+             raise ValueError(f"Unknown tar entry: {tarinfo.name}")
+
+
+ async def restore_resource(context: ApplicationContext, kbid: str, backup_id: str, resource_id: str):
+     download_stream = context.blob_storage.download(
+         bucket=settings.backups_bucket,
+         key=StorageKeys.RESOURCE.format(kbid=kbid, backup_id=backup_id, resource_id=resource_id),
+     )
+     reader = ResourceBackupReader(download_stream)
+     bm = None
+     while True:
+         item = await reader.read_item()
+         if isinstance(item, BrokerMessage):
+             # When the broker message is read, this means all cloud files
+             # and binaries of that resource have been read and imported
+             bm = item
+             bm.kbid = kbid
+             break
+
+         # Read the cloud file and its binary
+         cf = await reader.read_item()
+         assert isinstance(cf, CloudFile)
+         cf_binary = await reader.read_item()
+         assert isinstance(cf_binary, CloudFileBinary)
+         assert cf.uri == cf_binary.uri
+         await import_binary(context, kbid, cf, cf_binary.read)
+
+     await import_broker_message(context, kbid, bm)
+
+
+ async def restore_labels(context: ApplicationContext, kbid: str, backup_id: str):
+     raw = await context.blob_storage.downloadbytes(
+         bucket=settings.backups_bucket,
+         key=StorageKeys.LABELS.format(kbid=kbid, backup_id=backup_id),
+     )
+     labels = kb_pb2.Labels()
+     labels.ParseFromString(raw.getvalue())
+     await set_labels(context, kbid, labels)
+
+
+ async def restore_entities(context: ApplicationContext, kbid: str, backup_id: str):
+     raw = await context.blob_storage.downloadbytes(
+         bucket=settings.backups_bucket,
+         key=StorageKeys.ENTITIES.format(kbid=kbid, backup_id=backup_id),
+     )
+     entities = kb_pb2.EntitiesGroups()
+     entities.ParseFromString(raw.getvalue())
+     await set_entities_groups(context, kbid, entities)
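
For orientation: the reader above treats each backup object as a raw tar stream, one 512-byte ustar header per entry with the payload padded to the next 512-byte boundary, and the `cloud-files/...`, `binaries/...` and `broker-message/...` entry names distinguishing item types. A self-contained, standard-library sketch of that framing (the entry name and size here are made up for illustration, not taken from the diff):

import io
import tarfile

# Build an in-memory tar with one entry whose payload is not a multiple of
# 512 bytes, then parse it the way ResourceBackupReader does.
payload = b"x" * 700
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w") as tar:
    info = tarfile.TarInfo(name="binaries/some-uri")
    info.size = len(payload)
    tar.addfile(info, io.BytesIO(payload))
raw = buf.getvalue()

# A tar header is exactly 512 bytes (cf. read_tarinfo above).
header = tarfile.TarInfo.frombuf(raw[:512], encoding="utf-8", errors="strict")
assert header.name == "binaries/some-uri" and header.size == 700

# The payload is padded up to the next 512-byte boundary (cf. read_data/iter_data).
padding = 512 - (header.size % 512) if header.size % 512 != 0 else 0
assert padding == 324
assert raw[512 : 512 + header.size] == payload
next_header_offset = 512 + header.size + padding  # where the next entry starts
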
@@ -0,0 +1,37 @@
+ # Copyright (C) 2021 Bosutech XXI S.L.
+ #
+ # nucliadb is offered under the AGPL v3.0 and as commercial software.
+ # For commercial licensing, contact us at info@nuclia.com.
+ #
+ # AGPL:
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU Affero General Public License as
+ # published by the Free Software Foundation, either version 3 of the
+ # License, or (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU Affero General Public License for more details.
+ #
+ # You should have received a copy of the GNU Affero General Public License
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
+ #
+
+ from pydantic import Field
+ from pydantic_settings import BaseSettings
+
+
+ class BackupSettings(BaseSettings):
+     backups_bucket: str = Field(
+         default="backups", description="The bucket where the backups are stored."
+     )
+     restore_resources_concurrency: int = Field(
+         default=10, description="The number of concurrent resource restores."
+     )
+     backup_resources_concurrency: int = Field(
+         default=10, description="The number of concurrent resource backups."
+     )
+
+
+ settings = BackupSettings()
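
Because `BackupSettings` extends pydantic's `BaseSettings`, each field can be overridden by an environment variable of the same name (matched case-insensitively by default). A small sketch that mirrors the class above rather than importing it from the wheel:

import os

from pydantic import Field
from pydantic_settings import BaseSettings


class BackupSettings(BaseSettings):  # same shape as the class in the diff
    backups_bucket: str = Field(default="backups")
    restore_resources_concurrency: int = Field(default=10)


os.environ["BACKUPS_BUCKET"] = "my-backups"
os.environ["RESTORE_RESOURCES_CONCURRENCY"] = "4"

s = BackupSettings()
assert s.backups_bucket == "my-backups"
assert s.restore_resources_concurrency == 4  # read as a string, coerced to int
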
@@ -0,0 +1,126 @@
+ # Copyright (C) 2021 Bosutech XXI S.L.
+ #
+ # nucliadb is offered under the AGPL v3.0 and as commercial software.
+ # For commercial licensing, contact us at info@nuclia.com.
+ #
+ # AGPL:
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU Affero General Public License as
+ # published by the Free Software Foundation, either version 3 of the
+ # License, or (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU Affero General Public License for more details.
+ #
+ # You should have received a copy of the GNU Affero General Public License
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
+ #
+ from typing import Awaitable, Callable
+
+ from nucliadb.backups.create import backup_kb_retried
+ from nucliadb.backups.delete import delete_backup
+ from nucliadb.backups.models import CreateBackupRequest, DeleteBackupRequest, RestoreBackupRequest
+ from nucliadb.backups.restore import restore_kb_retried
+ from nucliadb.common.context import ApplicationContext
+ from nucliadb.tasks import create_consumer, create_producer
+ from nucliadb.tasks.consumer import NatsTaskConsumer
+ from nucliadb.tasks.producer import NatsTaskProducer
+
+
+ def creator_consumer() -> NatsTaskConsumer[CreateBackupRequest]:
+     consumer: NatsTaskConsumer = create_consumer(
+         name="backup_creator",
+         stream="backups",
+         stream_subjects=["backups.>"],
+         consumer_subject="backups.create",
+         callback=backup_kb_retried,
+         msg_type=CreateBackupRequest,
+         max_concurrent_messages=10,
+     )
+     return consumer
+
+
+ async def create(kbid: str, backup_id: str) -> None:
+     producer: NatsTaskProducer[CreateBackupRequest] = create_producer(
+         name="backup_creator",
+         stream="backups",
+         stream_subjects=["backups.>"],
+         producer_subject="backups.create",
+         msg_type=CreateBackupRequest,
+     )
+     msg = CreateBackupRequest(
+         kbid=kbid,
+         backup_id=backup_id,
+     )
+     await producer.send(msg)
+
+
+ def restorer_consumer() -> NatsTaskConsumer[RestoreBackupRequest]:
+     consumer: NatsTaskConsumer = create_consumer(
+         name="backup_restorer",
+         stream="backups",
+         stream_subjects=["backups.>"],
+         consumer_subject="backups.restore",
+         callback=restore_kb_retried,
+         msg_type=RestoreBackupRequest,
+         max_concurrent_messages=10,
+     )
+     return consumer
+
+
+ async def restore(kbid: str, backup_id: str) -> None:
+     producer: NatsTaskProducer[RestoreBackupRequest] = create_producer(
+         name="backup_restorer",
+         stream="backups",
+         stream_subjects=["backups.>"],
+         producer_subject="backups.restore",
+         msg_type=RestoreBackupRequest,
+     )
+     msg = RestoreBackupRequest(
+         kbid=kbid,
+         backup_id=backup_id,
+     )
+     await producer.send(msg)
+
+
+ def deleter_consumer() -> NatsTaskConsumer[DeleteBackupRequest]:
+     consumer: NatsTaskConsumer = create_consumer(
+         name="backup_deleter",
+         stream="backups",
+         stream_subjects=["backups.>"],
+         consumer_subject="backups.delete",
+         callback=delete_backup,
+         msg_type=DeleteBackupRequest,
+         max_concurrent_messages=2,
+     )
+     return consumer
+
+
+ async def delete(backup_id: str) -> None:
+     producer: NatsTaskProducer[DeleteBackupRequest] = create_producer(
+         name="backup_deleter",
+         stream="backups",
+         stream_subjects=["backups.>"],
+         producer_subject="backups.delete",
+         msg_type=DeleteBackupRequest,
+     )
+     msg = DeleteBackupRequest(
+         backup_id=backup_id,
+     )
+     await producer.send(msg)
+
+
+ async def initialize_consumers(context: ApplicationContext) -> list[Callable[[], Awaitable[None]]]:
+     creator = creator_consumer()
+     restorer = restorer_consumer()
+     deleter = deleter_consumer()
+     await creator.initialize(context)
+     await restorer.initialize(context)
+     await deleter.initialize(context)
+     return [
+         creator.finalize,
+         restorer.finalize,
+         deleter.finalize,
+     ]
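
Each operation is thus a producer/consumer pair on the same `backups` NATS stream, distinguished only by subject, with retries handled inside the wrapped callbacks. A hedged usage sketch (it assumes a fully initialized, NATS-backed `ApplicationContext`; the kbid and backup id values are placeholders):

from nucliadb.backups import tasks  # module path as imported by ingest/app.py below


async def run_backup_cycle(context) -> None:
    # Start the three consumers; each returned callable shuts one down.
    finalizers = await tasks.initialize_consumers(context)
    try:
        # Enqueue work by publishing to the matching subject.
        await tasks.create("my-kbid", "backup-001")      # backups.create
        await tasks.restore("other-kbid", "backup-001")  # backups.restore
        await tasks.delete("backup-001")                 # backups.delete
    finally:
        for finalize in finalizers:
            await finalize()
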
@@ -0,0 +1,32 @@
+ # Copyright (C) 2021 Bosutech XXI S.L.
+ #
+ # nucliadb is offered under the AGPL v3.0 and as commercial software.
+ # For commercial licensing, contact us at info@nuclia.com.
+ #
+ # AGPL:
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU Affero General Public License as
+ # published by the Free Software Foundation, either version 3 of the
+ # License, or (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU Affero General Public License for more details.
+ #
+ # You should have received a copy of the GNU Affero General Public License
+ # along with this program. If not, see <http://www.gnu.org/licenses/>.
+ #
+
+ from nucliadb.backups.const import StorageKeys
+ from nucliadb.backups.settings import settings
+ from nucliadb_utils.storages.storage import Storage
+
+
+ async def exists_backup(storage: Storage, backup_id: str) -> bool:
+     async for _ in storage.iterate_objects(
+         bucket=settings.backups_bucket,
+         prefix=StorageKeys.BACKUP_PREFIX.format(backup_id=backup_id),
+     ):
+         return True
+     return False
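
`iterate_objects` yields lazily, so `exists_backup` returns as soon as a single object under the backup's prefix is seen, without listing the whole backup. A hedged example of the kind of guard this enables (the wrapper function is illustrative, not part of the diff):

async def ensure_backup_id_is_free(storage, backup_id: str) -> None:
    # Refuse to reuse an id that still has objects under its prefix.
    if await exists_backup(storage, backup_id):
        raise ValueError(f"Backup {backup_id} already exists")
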
@@ -36,7 +36,7 @@ from nucliadb_protos.utils_pb2 import Relation
  
  
  class DummyWriterStub: # pragma: no cover
-     def __init__(self):
+     def __init__(self: "DummyWriterStub"):
          self.calls: dict[str, list[Any]] = {}
  
      async def NewShard(self, data): # pragma: no cover
@@ -82,7 +82,7 @@ class DummyWriterStub: # pragma: no cover
  
  
  class DummyReaderStub: # pragma: no cover
-     def __init__(self):
+     def __init__(self: "DummyReaderStub"):
          self.calls: dict[str, list[Any]] = {}
  
      async def GetShard(self, data): # pragma: no cover
@@ -281,7 +281,7 @@ class KBShardManager:
  class StandaloneKBShardManager(KBShardManager):
      max_ops_before_checks = 200
  
-     def __init__(self):
+     def __init__(self: "StandaloneKBShardManager"):
          super().__init__()
          self._lock = asyncio.Lock()
          self._change_count: dict[tuple[str, str], int] = {}
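
The three hunks above only add an explicit annotation on `self`, which changes nothing at runtime; type checkers infer the same type for a bare `self`, so the explicit form is a stylistic choice, presumably to satisfy a stricter lint or type-check configuration. A minimal illustration of the equivalence:

class Stub:
    # Annotating self is legal and optional: checkers infer "Stub" either way.
    def __init__(self: "Stub") -> None:
        self.calls: dict[str, list] = {}
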
@@ -35,8 +35,12 @@ it's transaction
  
  """
  
- import sys
  from functools import wraps
+ from typing import Awaitable, Callable, TypeVar
+
+ from typing_extensions import Concatenate, ParamSpec
+
+ from nucliadb.common.maindb.driver import Transaction
  
  from . import kb as kb_dm
  from . import labels as labels_dm
@@ -44,34 +48,24 @@ from . import resources as resources_dm
  from . import synonyms as synonyms_dm
  from .utils import with_ro_transaction, with_transaction
  
- # XXX: we are using the not exported _ParamSpec to support 3.9. Whenever we
- # upgrade to >= 3.10 we'll be able to use ParamSpecKwargs and improve the
- # typing. We are abusing of ParamSpec anywat to better support text editors, so
- # we also need to ignore some mypy complains
-
- __python_version = (sys.version_info.major, sys.version_info.minor)
- if __python_version == (3, 9):
-     from typing_extensions import ParamSpec
- else:
-     from typing import ParamSpec # type: ignore
-
  P = ParamSpec("P")
+ T = TypeVar("T")
  
  
- def ro_txn_wrap(fun: P) -> P: # type: ignore
+ def ro_txn_wrap(fun: Callable[Concatenate[Transaction, P], Awaitable[T]]) -> Callable[P, Awaitable[T]]:
      @wraps(fun)
-     async def wrapper(**kwargs: P.kwargs):
+     async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
          async with with_ro_transaction() as txn:
-             return await fun(txn, **kwargs)
+             return await fun(txn, *args, **kwargs)
  
      return wrapper
  
  
- def rw_txn_wrap(fun: P) -> P: # type: ignore
+ def rw_txn_wrap(fun: Callable[Concatenate[Transaction, P], Awaitable[T]]) -> Callable[P, Awaitable[T]]:
      @wraps(fun)
-     async def wrapper(**kwargs: P.kwargs):
+     async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
          async with with_transaction() as txn:
-             result = await fun(txn, **kwargs)
+             result = await fun(txn, *args, **kwargs)
              await txn.commit()
              return result
  
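
The rewritten decorators use `Concatenate` to state that the decorated function takes a `Transaction` plus arbitrary parameters `P`, while the returned wrapper exposes just `P`, so callers never pass the transaction themselves. A self-contained sketch of the same pattern (the `Transaction` here is a dummy stand-in, not nucliadb's driver class):

from functools import wraps
from typing import Awaitable, Callable, TypeVar

from typing_extensions import Concatenate, ParamSpec

P = ParamSpec("P")
T = TypeVar("T")


class Transaction:  # dummy stand-in for nucliadb.common.maindb.driver.Transaction
    pass


def with_txn(fun: Callable[Concatenate[Transaction, P], Awaitable[T]]) -> Callable[P, Awaitable[T]]:
    @wraps(fun)
    async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        # Supply the transaction here; forward everything else untouched.
        return await fun(Transaction(), *args, **kwargs)

    return wrapper


@with_txn
async def get_title(txn: Transaction, kbid: str) -> str:
    return f"title of {kbid}"

# Callers never pass a transaction: get_title's visible type is (kbid: str) -> Awaitable[str].
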
@@ -41,7 +41,7 @@ class EntitiesMetaCache:
      change the structure of this class or we'll break the index.
      """
  
-     def __init__(self):
+     def __init__(self: "EntitiesMetaCache") -> None:
          self.deleted_entities: dict[str, list[str]] = {}
          self.duplicate_entities: dict[str, dict[str, list[str]]] = {}
          # materialize by value for faster lookups
@@ -40,7 +40,7 @@ from nucliadb_protos import resources_pb2, writer_pb2
  from nucliadb_utils.const import Streams
  from nucliadb_utils.transaction import MaxTransactionSizeExceededError
  
- BinaryStream = AsyncGenerator[bytes, None]
+ BinaryStream = AsyncIterator[bytes]
  BinaryStreamGenerator = Callable[[int], BinaryStream]
  
  
@@ -237,8 +237,11 @@ async def download_binary(
      context: ApplicationContext, cf: resources_pb2.CloudFile
  ) -> AsyncGenerator[bytes, None]:
      bucket_name = context.blob_storage.get_bucket_name_from_cf(cf)
+     downloaded_bytes = 0
      async for data in context.blob_storage.download(bucket_name, cf.uri):
          yield data
+         downloaded_bytes += len(data)
+     assert downloaded_bytes == cf.size, "Downloaded bytes do not match the expected size"
  
  
  async def get_entities(context: ApplicationContext, kbid: str) -> kb_pb2.EntitiesGroups:
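
The `download_binary` change passes chunks through unchanged while counting bytes, then asserts the total against the size recorded on the `CloudFile` once the stream is exhausted. The same pass-through pattern in isolation (a sketch, not a nucliadb API):

from typing import AsyncIterator


async def verified(stream: AsyncIterator[bytes], expected_size: int) -> AsyncIterator[bytes]:
    seen = 0
    async for chunk in stream:
        seen += len(chunk)
        yield chunk  # consumers receive the data unchanged
    # Note: this line only runs if the consumer drains the stream to the end.
    assert seen == expected_size, "Downloaded bytes do not match the expected size"
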
@@ -416,6 +419,8 @@ class ExportStreamReader:
  class TaskRetryHandler:
      """
      Class that wraps an import/export task and adds retry logic to it.
+
+     TODO: This should be refactored to use generic task retry logic at tasks/retries.py::TaskRetryHandler
      """
  
      def __init__(
nucliadb/ingest/app.py CHANGED
@@ -22,6 +22,7 @@ import importlib.metadata
  from typing import Awaitable, Callable
  
  from nucliadb import health
+ from nucliadb.backups.tasks import initialize_consumers as initialize_backup_consumers
  from nucliadb.common.cluster.utils import setup_cluster, teardown_cluster
  from nucliadb.common.context import ApplicationContext
  from nucliadb.common.nidx import start_nidx_utility
@@ -154,6 +155,7 @@ async def main_subscriber_workers(): # pragma: no cover
      await exports_consumer.initialize(context)
      imports_consumer = get_imports_consumer()
      await imports_consumer.initialize(context)
+     backup_consumers_finalizers = await initialize_backup_consumers(context)
  
      await run_until_exit(
          [
@@ -165,7 +167,10 @@ async def main_subscriber_workers(): # pragma: no cover
              metrics_server.shutdown,
              grpc_health_finalizer,
              context.finalize,
+             exports_consumer.finalize,
+             imports_consumer.finalize,
          ]
+         + backup_consumers_finalizers
          + finalizers
      )
  
@@ -216,6 +221,7 @@ def run_subscriber_workers() -> None: # pragma: no cover
      - audit fields subscriber
      - export/import subscriber
      - materializer subscriber
+     - backups subscribers
      """
      setup_configuration()
      asyncio.run(main_subscriber_workers())
@@ -28,4 +28,4 @@ class InvalidPBClass(Exception):
      def __init__(self, source: Type, destination: Type):
          self.source = source
          self.destination = destination
-         super().__init__("Source and destination does not match " f"{self.source} - {self.destination}")
+         super().__init__(f"Source and destination does not match {self.source} - {self.destination}")
@@ -858,9 +858,9 @@ class Resource:
          for field_vectors in fields_vectors:
              # Bw/c with extracted vectors without vectorsets
              if not field_vectors.vectorset_id:
-                 assert (
-                     len(vectorsets) == 1
-                 ), "Invalid broker message, can't ingest vectors from unknown vectorset to KB with multiple vectorsets"
+                 assert len(vectorsets) == 1, (
+                     "Invalid broker message, can't ingest vectors from unknown vectorset to KB with multiple vectorsets"
+                 )
                  vectorset = list(vectorsets.values())[0]
  
              else:
@@ -477,9 +477,9 @@ class ProcessingEngine:
  
  
  class DummyProcessingEngine(ProcessingEngine):
-     def __init__(self):
+     def __init__(self: "DummyProcessingEngine"):
          self.calls: list[list[Any]] = []
-         self.values = defaultdict(list)
+         self.values: dict[str, Any] = defaultdict(list)
          self.onprem = True
  
      async def initialize(self):
@@ -189,7 +189,7 @@ async def managed_serialize(
  
      include_values = ResourceProperties.VALUES in show
  
-     include_extracted_data = ResourceProperties.EXTRACTED in show and extracted is not []
+     include_extracted_data = ResourceProperties.EXTRACTED in show and extracted != []
  
      if ResourceProperties.BASIC in show:
          await orm_resource.get_basic()
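
The final hunk fixes an identity-vs-equality bug: `extracted is not []` compares against a freshly created list object, so it is always `True` and the extra condition had no effect, whereas `extracted != []` compares by value. A two-line demonstration:

a: list = []
b: list = []
assert a is not b  # distinct objects, so the old `is not []` test was always True
assert a == b      # value equality, which is what `extracted != []` now tests
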