nucliadb-utils 4.0.3.post577__py3-none-any.whl → 4.0.3.post579__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
nucliadb_utils/settings.py

@@ -51,7 +51,6 @@ http_settings = HTTPSettings()
 class FileBackendConfig(Enum):
     GCS = "gcs"
     S3 = "s3"
-    PG = "pg"
     LOCAL = "local"
     NOT_SET = "notset"  # setting not provided
 
nucliadb_utils/storages/gcs.py

@@ -26,7 +26,7 @@ import socket
 from concurrent.futures import ThreadPoolExecutor
 from copy import deepcopy
 from datetime import datetime
-from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, cast
 from urllib.parse import quote_plus
 
 import aiohttp
@@ -50,6 +50,7 @@ from nucliadb_utils.storages.exceptions import (
 from nucliadb_utils.storages.storage import (
     ObjectInfo,
     ObjectMetadata,
+    Range,
     Storage,
     StorageField,
 )
@@ -162,11 +163,13 @@ class GCSStorageField(StorageField):
         assert data["resource"]["name"] == destination_uri
 
     @storage_ops_observer.wrap({"type": "iter_data"})
-    async def iter_data(self, headers=None):
+    async def iter_data(
+        self, range: Optional[Range] = None
+    ) -> AsyncGenerator[bytes, None]:
         attempt = 1
         while True:
             try:
-                async for chunk in self._inner_iter_data(headers=headers):
+                async for chunk in self._inner_iter_data(range=range):
                     yield chunk
                 break
             except ReadingResponseContentException:
@@ -185,23 +188,26 @@ class GCSStorageField(StorageField):
                 attempt += 1
 
     @storage_ops_observer.wrap({"type": "inner_iter_data"})
-    async def _inner_iter_data(self, headers=None):
-        if headers is None:
-            headers = {}
+    async def _inner_iter_data(self, range: Optional[Range] = None):
+        """
+        Iterate through object data.
+        """
+        range = range or Range()
+        assert self.storage.session is not None
 
+        headers = await self.storage.get_access_headers()
+        if range.any():
+            headers["Range"] = range.to_header()
         key = self.field.uri if self.field else self.key
         if self.field is None:
            bucket = self.bucket
         else:
             bucket = self.field.bucket_name
-
         url = "{}/{}/o/{}".format(
             self.storage.object_base_url,
             bucket,
             quote_plus(key),
         )
-        headers.update(await self.storage.get_access_headers())
-
         async with self.storage.session.get(
             url, headers=headers, params={"alt": "media"}, timeout=-1
         ) as api_resp:
@@ -209,11 +215,6 @@ class GCSStorageField(StorageField):
                 text = await api_resp.text()
                 if api_resp.status == 404:
                     raise KeyError(f"Google cloud file not found : \n {text}")
-                elif api_resp.status == 401:
-                    logger.warning(f"Invalid google cloud credentials error: {text}")
-                    raise KeyError(
-                        content={f"Google cloud invalid credentials : \n {text}"}
-                    )
                 raise GoogleCloudException(f"{api_resp.status}: {text}")
             while True:
                 try:
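Note: with this change the GCS reader no longer accepts arbitrary headers; a partial read is expressed as a Range value and translated into a standard HTTP Range header on the media-download request shown above. A rough standalone sketch of the equivalent request follows; the base URL constant, the bucket/object names and the token handling are illustrative assumptions, not code from this package.

```python
import aiohttp
from urllib.parse import quote_plus

# Assumption: the public GCS JSON API endpoint; the real code uses storage.object_base_url.
OBJECT_BASE_URL = "https://storage.googleapis.com/storage/v1/b"


async def read_first_kib(session: aiohttp.ClientSession, token: str, bucket: str, key: str) -> bytes:
    # Same URL shape as GCSStorageField._inner_iter_data builds above.
    url = f"{OBJECT_BASE_URL}/{bucket}/o/{quote_plus(key)}"
    headers = {
        "Authorization": f"Bearer {token}",  # roughly what get_access_headers() provides
        "Range": "bytes=0-1023",             # inclusive range, as produced by Range.to_header()
    }
    async with session.get(url, headers=headers, params={"alt": "media"}) as resp:
        resp.raise_for_status()
        return await resp.read()
```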
@@ -225,16 +226,6 @@ class GCSStorageField(StorageField):
                 else:
                     break
 
-    @storage_ops_observer.wrap({"type": "read_range"})
-    async def read_range(self, start: int, end: int) -> AsyncGenerator[bytes, None]:
-        """
-        Iterate through ranges of data
-        """
-        async for chunk in self.iter_data(
-            headers={"Range": f"bytes={start}-{end - 1}"}
-        ):
-            yield chunk
-
     @backoff.on_exception(
         backoff.expo,
         RETRIABLE_EXCEPTIONS,
@@ -443,18 +434,8 @@ class GCSStorageField(StorageField):
         async with self.storage.session.get(url, headers=headers) as api_resp:
             if api_resp.status == 200:
                 data = await api_resp.json()
-                metadata = data.get("metadata") or {}
-                metadata = {k.lower(): v for k, v in metadata.items()}
-                size = metadata.get("size") or data.get("size") or 0
-                content_type = (
-                    metadata.get("content_type") or data.get("contentType") or ""
-                )
-                filename = metadata.get("filename") or key.split("/")[-1]
-                return ObjectMetadata(
-                    filename=filename,
-                    size=int(size),
-                    content_type=content_type,
-                )
+                data = cast(dict[str, Any], data)
+                return parse_object_metadata(data, key)
             else:
                 return None
 
@@ -758,3 +739,31 @@ class GCSStorage(Storage):
             for item in items:
                 yield ObjectInfo(name=item["name"])
             page_token = data.get("nextPageToken")
+
+
+def parse_object_metadata(object_data: dict[str, Any], key: str) -> ObjectMetadata:
+    custom_metadata: dict[str, str] = object_data.get("metadata") or {}
+    # Lowercase all keys for backwards compatibility with old custom metadata
+    custom_metadata = {k.lower(): v for k, v in custom_metadata.items()}
+
+    # Parse size
+    custom_size = custom_metadata.get("size")
+    if not custom_size or custom_size == "0":
+        data_size = object_data.get("size")
+        size = int(data_size) if data_size else 0
+    else:
+        size = int(custom_size)
+
+    # Parse content-type
+    content_type = (
+        custom_metadata.get("content_type") or object_data.get("contentType") or ""
+    )
+
+    # Parse filename
+    filename = custom_metadata.get("filename") or key.split("/")[-1]
+
+    return ObjectMetadata(
+        filename=filename,
+        size=int(size),
+        content_type=content_type,
+    )
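Note: the new module-level parse_object_metadata helper makes the precedence rules explicit: custom metadata attached at upload time wins, except for a missing or zero size, which falls back to the object resource's own size field. A small illustrative call, using a made-up GCS object resource (the import assumes the new package version is installed):

```python
from nucliadb_utils.storages.gcs import parse_object_metadata

fake_gcs_object = {
    "size": "2048",  # GCS reports object size as a string
    "contentType": "application/pdf",
    "metadata": {"FILENAME": "report.pdf", "SIZE": "0"},  # mixed-case keys from older uploads
}
meta = parse_object_metadata(fake_gcs_object, key="kb/resource/file/report.pdf")
# The custom "0" size is ignored in favour of the object resource size, and the
# upper-cased custom keys are still honoured thanks to the lowercasing step.
assert meta.size == 2048
assert meta.content_type == "application/pdf"
assert meta.filename == "report.pdf"
```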
nucliadb_utils/storages/local.py

@@ -24,7 +24,7 @@ import json
 import os
 import shutil
 from datetime import datetime
-from typing import AsyncGenerator, AsyncIterator, Dict, Optional
+from typing import AsyncGenerator, AsyncIterator, Optional
 
 import aiofiles
 from nucliadb_protos.resources_pb2 import CloudFile
@@ -33,6 +33,7 @@ from nucliadb_utils.storages import CHUNK_SIZE
 from nucliadb_utils.storages.storage import (
     ObjectInfo,
     ObjectMetadata,
+    Range,
     Storage,
     StorageField,
 )
@@ -77,7 +78,9 @@ class LocalStorageField(StorageField):
         destination_path = f"{destination_bucket_path}/{destination_uri}"
         shutil.copy(origin_path, destination_path)
 
-    async def iter_data(self, headers=None):
+    async def iter_data(
+        self, range: Optional[Range] = None
+    ) -> AsyncGenerator[bytes, None]:
         key = self.field.uri if self.field else self.key
         if self.field is None:
             bucket = self.bucket
@@ -86,34 +89,36 @@ class LocalStorageField(StorageField):
 
         path = self.storage.get_file_path(bucket, key)
         async with aiofiles.open(path, mode="rb") as resp:
+
+            if range and range.start is not None:
+                # Seek to the start of the range
+                await resp.seek(range.start)
+
+            bytes_read = 0
+            bytes_to_read = None  # If None, read until EOF
+            if range and range.end is not None:
+                # Range is inclusive
+                bytes_to_read = range.end - (range.start or 0) + 1
+
             while True:
-                data = await resp.read(CHUNK_SIZE)
-                if not data:
+                chunk_size = CHUNK_SIZE
+                if bytes_to_read is not None:
+                    if bytes_read >= bytes_to_read:
+                        # Reached the end of the range
+                        break
+                    chunk_size = min(CHUNK_SIZE, bytes_to_read)
+
+                if chunk_size <= 0:
+                    # No more data to read
                     break
-                yield data
 
-    async def read_range(self, start: int, end: int) -> AsyncGenerator[bytes, None]:
-        """
-        Iterate through ranges of data
-        """
-        key = self.field.uri if self.field else self.key
-        if self.field is None:
-            bucket = self.bucket
-        else:
-            bucket = self.field.bucket_name
+                data = await resp.read(chunk_size)
+                if not data:
+                    # EOF
+                    break
 
-        path = self.storage.get_file_path(bucket, key)
-        async with aiofiles.open(path, "rb") as resp:
-            await resp.seek(start)
-            count = 0
-            data = await resp.read(CHUNK_SIZE)
-            while data and count < end:
-                if count + len(data) > end:
-                    new_end = end - count
-                    data = data[:new_end]
                 yield data
-                count += len(data)
-                data = await resp.read(CHUNK_SIZE)
+                bytes_read += len(data)
 
     async def start(self, cf: CloudFile) -> CloudFile:
         if self.field is not None and self.field.upload_uri != "":
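Note: LocalStorageField.iter_data now folds the old read_range logic into a single loop: seek to range.start, then cap every read so that at most end - start + 1 bytes are returned (the range is inclusive on both ends). A standalone sketch of the same arithmetic on a plain file; the file name is a placeholder and is assumed to be at least 200 bytes long:

```python
def read_inclusive_range(path: str, start: int, end: int, chunk_size: int = 64) -> bytes:
    out = b""
    bytes_to_read = end - start + 1  # inclusive on both ends
    with open(path, "rb") as f:
        f.seek(start)
        while len(out) < bytes_to_read:
            chunk = f.read(min(chunk_size, bytes_to_read - len(out)))
            if not chunk:  # EOF before the requested end
                break
            out += chunk
    return out


# bytes 100..199 inclusive -> exactly 100 bytes (given the file is long enough)
assert len(read_inclusive_range("blob.bin", 100, 199)) == 100
```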
@@ -285,17 +290,9 @@ class LocalStorage(Storage):
         for key in glob.glob(f"{bucket}/{prefix}*"):
             yield ObjectInfo(name=key)
 
-    async def download(
-        self, bucket_name: str, key: str, headers: Optional[Dict[str, str]] = None
-    ):
+    async def download(self, bucket_name: str, key: str, range: Optional[Range] = None):
         key_path = self.get_file_path(bucket_name, key)
         if not os.path.exists(key_path):
             return
-
-        async with aiofiles.open(key_path, mode="rb") as f:
-            while True:
-                body = await f.read(self.chunk_size)
-                if body == b"" or body is None:
-                    break
-                else:
-                    yield body
+        async for chunk in super().download(bucket_name, key, range=range):
+            yield chunk
nucliadb_utils/storages/s3.py

@@ -37,6 +37,7 @@ from nucliadb_utils.storages.exceptions import UnparsableResponse
 from nucliadb_utils.storages.storage import (
     ObjectInfo,
     ObjectMetadata,
+    Range,
     Storage,
     StorageField,
 )
@@ -81,15 +82,21 @@ class S3StorageField(StorageField):
         jitter=backoff.random_jitter,
         max_tries=MAX_TRIES,
     )
-    async def _download(self, uri, bucket, **kwargs):
-        if "headers" in kwargs:
-            for key, value in kwargs["headers"].items():
-                kwargs[key] = value
-            del kwargs["headers"]
-        try:
-            return await self.storage._s3aioclient.get_object(
-                Bucket=bucket, Key=uri, **kwargs
+    async def _download(
+        self,
+        uri,
+        bucket,
+        range: Optional[Range] = None,
+    ):
+        range = range or Range()
+        if range.any():
+            coro = self.storage._s3aioclient.get_object(
+                Bucket=bucket, Key=uri, Range=range.to_header()
             )
+        else:
+            coro = self.storage._s3aioclient.get_object(Bucket=bucket, Key=uri)
+        try:
+            return await coro
         except botocore.exceptions.ClientError as e:
             error_code = parse_status_code(e)
             if error_code == 404:
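Note: on S3 the same Range value is forwarded to GetObject's Range parameter, which accepts the identical bytes=start-end syntax; when no range is requested the parameter is omitted entirely. A rough aiobotocore-style sketch of that call shape (client construction omitted; the helper and its names are placeholders):

```python
from typing import Optional


async def get_object_bytes(s3client, bucket: str, key: str, range_header: Optional[str] = None) -> bytes:
    # Mirrors S3StorageField._download: only pass Range when a range was requested.
    if range_header:
        resp = await s3client.get_object(Bucket=bucket, Key=key, Range=range_header)
    else:
        resp = await s3client.get_object(Bucket=bucket, Key=key)
    return await resp["Body"].read()


# e.g. await get_object_bytes(client, "my-bucket", "kb/file", range_header="bytes=0-1023")
```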
@@ -97,18 +104,16 @@ class S3StorageField(StorageField):
             else:
                 raise
 
-    async def iter_data(self, **kwargs):
+    async def iter_data(
+        self, range: Optional[Range] = None
+    ) -> AsyncGenerator[bytes, None]:
         # Suports field and key based iter
         uri = self.field.uri if self.field else self.key
         if self.field is None:
             bucket = self.bucket
         else:
             bucket = self.field.bucket_name
-
-        downloader = await self._download(uri, bucket, **kwargs)
-
-        # we do not want to timeout ever from this...
-        # downloader['Body'].set_socket_timeout(999999)
+        downloader = await self._download(uri, bucket, range=range)
         stream = downloader["Body"]
         data = await stream.read(CHUNK_SIZE)
         while True:
@@ -117,13 +122,6 @@ class S3StorageField(StorageField):
             yield data
             data = await stream.read(CHUNK_SIZE)
 
-    async def read_range(self, start: int, end: int) -> AsyncGenerator[bytes, None]:
-        """
-        Iterate through ranges of data
-        """
-        async for chunk in self.iter_data(Range=f"bytes={start}-{end - 1}"):
-            yield chunk
-
     async def _abort_multipart(self):
         try:
             mpu = self.field.resumable_uri
@@ -296,18 +294,9 @@ class S3StorageField(StorageField):
 
         try:
             obj = await self.storage._s3aioclient.head_object(Bucket=bucket, Key=key)
-            if obj is not None:
-                metadata = obj.get("Metadata") or {}
-                size = metadata.get("size") or obj.get("ContentLength") or 0
-                content_type = (
-                    metadata.get("content_type") or obj.get("ContentType") or ""
-                )
-                filename = metadata.get("filename") or key.split("/")[-1]
-                return ObjectMetadata(
-                    size=int(size), content_type=content_type, filename=filename
-                )
-            else:
+            if obj is None:
                 return None
+            return parse_object_metadata(obj, key)
         except botocore.exceptions.ClientError as e:
             error_code = parse_status_code(e)
             if error_code == 404:
@@ -560,3 +549,21 @@ def parse_status_code(error: botocore.exceptions.ClientError) -> int:
             errors.capture_message(msg, "error", scope)
 
         raise UnparsableResponse(msg) from error
+
+
+def parse_object_metadata(obj: dict, key: str) -> ObjectMetadata:
+    custom_metadata = obj.get("Metadata") or {}
+    # Parse size
+    custom_size = custom_metadata.get("size")
+    if custom_size is None or custom_size == "0":
+        size = 0
+        content_lenght = obj.get("ContentLength")
+        if content_lenght is not None:
+            size = int(content_lenght)
+    else:
+        size = int(custom_size)
+    # Content type
+    content_type = custom_metadata.get("content_type") or obj.get("ContentType") or ""
+    # Filename
+    filename = custom_metadata.get("filename") or key.split("/")[-1]
+    return ObjectMetadata(size=size, content_type=content_type, filename=filename)
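Note: the S3 helper applies the same precedence as the GCS one, but without a key-lowercasing step (presumably because S3 already returns user metadata keys lowercased). A quick illustrative call against a head_object-shaped dict, with made-up values:

```python
from nucliadb_utils.storages.s3 import parse_object_metadata

fake_head_object = {
    "ContentLength": 1024,
    "ContentType": "text/plain",
    "Metadata": {"filename": "notes.txt", "size": "1024"},
}
meta = parse_object_metadata(fake_head_object, key="kb/resource/file/notes.txt")
assert (meta.filename, meta.size, meta.content_type) == ("notes.txt", 1024, "text/plain")
```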
nucliadb_utils/storages/storage.py

@@ -22,12 +22,12 @@ from __future__ import annotations
 import abc
 import hashlib
 import uuid
+from dataclasses import dataclass
 from io import BytesIO
 from typing import (
     Any,
     AsyncGenerator,
     AsyncIterator,
-    Dict,
     List,
     Optional,
     Tuple,
@@ -71,6 +71,23 @@ class ObjectMetadata(BaseModel):
     size: int
 
 
+@dataclass
+class Range:
+    """
+    Represents a range of bytes to be downloaded from a file. The range is inclusive.
+    The start and end values are 0-based.
+    """
+
+    start: Optional[int] = None
+    end: Optional[int] = None
+
+    def any(self) -> bool:
+        return self.start is not None or self.end is not None
+
+    def to_header(self) -> str:
+        return f"bytes={self.start or 0}-{self.end or ''}"
+
+
 class StorageField(abc.ABC, metaclass=abc.ABCMeta):
     storage: Storage
     bucket: str
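Note: Range is the new transport-agnostic way to ask for a partial download; both ends are inclusive and 0-based, and either end may be left open. A few illustrative values (assuming the new package version is installed):

```python
from nucliadb_utils.storages.storage import Range

assert not Range().any()  # no range requested: the backends do a full download
assert Range(start=100, end=199).to_header() == "bytes=100-199"  # second 100 bytes
assert Range(start=1000).to_header() == "bytes=1000-"            # open-ended suffix
assert Range(end=499).to_header() == "bytes=0-499"               # first 500 bytes
```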
@@ -93,12 +110,9 @@ class StorageField(abc.ABC, metaclass=abc.ABCMeta):
     async def upload(self, iterator: AsyncIterator, origin: CloudFile) -> CloudFile: ...
 
     @abc.abstractmethod
-    async def iter_data(self, headers=None) -> AsyncGenerator[bytes, None]:  # type: ignore
-        raise NotImplementedError()
-        yield b""
-
-    @abc.abstractmethod
-    async def read_range(self, start: int, end: int) -> AsyncGenerator[bytes, None]:
+    async def iter_data(
+        self, range: Optional[Range] = None
+    ) -> AsyncGenerator[bytes, None]:
         raise NotImplementedError()
         yield b""
 
@@ -433,16 +447,16 @@ class Storage(abc.ABC, metaclass=abc.ABCMeta):
         return await destination.upload(safe_iterator, origin)
 
     async def download(
-        self, bucket: str, key: str, headers: Optional[Dict[str, str]] = None
+        self,
+        bucket: str,
+        key: str,
+        range: Optional[Range] = None,
     ):
         destination: StorageField = self.field_klass(
             storage=self, bucket=bucket, fullkey=key
         )
-        if headers is None:
-            headers = {}
-
         try:
-            async for data in destination.iter_data(headers=headers):
+            async for data in destination.iter_data(range=range):
                 yield data
         except KeyError:
             yield None
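Note: callers that previously passed a raw headers dict to Storage.download() now express partial reads with a Range value. A hedged usage sketch (how the storage utility is obtained is outside this diff; the helper below is illustrative):

```python
from nucliadb_utils.storages.storage import Range, Storage


async def first_mebibyte(storage: Storage, bucket: str, key: str) -> bytes:
    buf = b""
    # Range is inclusive, so 0..1048575 is exactly 1 MiB.
    async for chunk in storage.download(bucket, key, range=Range(start=0, end=1024 * 1024 - 1)):
        if chunk is None:  # download() yields None when the key does not exist
            break
        buf += chunk
    return buf
```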
nucliadb_utils/tests/fixtures.py

@@ -33,8 +33,8 @@ def lazy_storage_fixture():
         return [lazy_fixture.lf("gcs_storage")]
     elif backend == "s3":
         return [lazy_fixture.lf("s3_storage")]
-    elif backend == "pg":
-        return [lazy_fixture.lf("pg_storage")]
+    elif backend == "local":
+        return [lazy_fixture.lf("local_storage")]
     else:
         print(f"Unknown storage backend {backend}, using gcs")
         return [lazy_fixture.lf("gcs_storage")]
nucliadb_utils/utilities.py

@@ -138,17 +138,6 @@ async def get_storage(
         await gcsutil.initialize(service_name)
         logger.info("Configuring GCS Storage")
 
-    elif storage_settings.file_backend == FileBackendConfig.PG:
-        from nucliadb_utils.storages.pg import PostgresStorage
-
-        pgutil = PostgresStorage(
-            storage_settings.driver_pg_url,  # type: ignore
-            connection_pool_max_size=storage_settings.driver_pg_connection_pool_max_size,
-        )
-        set_utility(Utility.STORAGE, pgutil)
-        await pgutil.initialize()
-        logger.info("Configuring Postgres Storage")
-
    elif storage_settings.file_backend == FileBackendConfig.LOCAL:
         if storage_settings.local_files is None:
             raise ConfigurationError("LOCAL_FILES env var not configured")
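Note: with the PG member removed from FileBackendConfig and the corresponding branch dropped from get_storage, a deployment that still configures the "pg" file backend now fails when the configured string is parsed into the enum (at settings validation) instead of wiring up a Postgres-backed storage. A minimal sketch of that enum behaviour, mirroring the enum as it stands after this diff:

```python
from enum import Enum


class FileBackendConfig(Enum):
    GCS = "gcs"
    S3 = "s3"
    LOCAL = "local"
    NOT_SET = "notset"


FileBackendConfig("local")  # fine
FileBackendConfig("pg")     # raises ValueError: 'pg' is not a valid FileBackendConfig
```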
nucliadb_utils-4.0.3.post579.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: nucliadb_utils
-Version: 4.0.3.post577
+Version: 4.0.3.post579
 Home-page: https://nuclia.com
 License: BSD
 Classifier: Development Status :: 4 - Beta
@@ -23,8 +23,8 @@ Requires-Dist: PyNaCl
 Requires-Dist: pyjwt >=2.4.0
 Requires-Dist: memorylru >=1.1.2
 Requires-Dist: mrflagly
-Requires-Dist: nucliadb-protos >=4.0.3.post577
-Requires-Dist: nucliadb-telemetry >=4.0.3.post577
+Requires-Dist: nucliadb-protos >=4.0.3.post579
+Requires-Dist: nucliadb-telemetry >=4.0.3.post579
 Provides-Extra: cache
 Requires-Dist: redis >=4.3.4 ; extra == 'cache'
 Requires-Dist: orjson >=3.6.7 ; extra == 'cache'
nucliadb_utils-4.0.3.post579.dist-info/RECORD

@@ -12,11 +12,11 @@ nucliadb_utils/nats.py,sha256=7hRKMflwxK-p_L0KFO5jibOGhzSw2F24mKvPG-A_iN8,8224
 nucliadb_utils/partition.py,sha256=0tmXuwRM_v-OmKoRkB--OMDzomVEkmgxqMmNNomI_24,1173
 nucliadb_utils/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 nucliadb_utils/run.py,sha256=bKMfsPEK6WdWfiPyWPUxCqcLo4tq6eOwyaf910TOwBk,1713
-nucliadb_utils/settings.py,sha256=Dd6h_BavMkt3e8qA3btZjt2si11x7tvpOC_WYMxqrDM,7252
+nucliadb_utils/settings.py,sha256=WVL2u_jCkm7Uf6a2njOZetHM_nU0hwDVhLqfH0k5Yi4,7238
 nucliadb_utils/signals.py,sha256=5r53hZvZmwgKdri5jHEjuHmiaq5TyusUUvjoq2uliIc,2704
 nucliadb_utils/store.py,sha256=kQ35HemE0v4_Qg6xVqNIJi8vSFAYQtwI3rDtMsNy62Y,890
 nucliadb_utils/transaction.py,sha256=CQpsuF-E2omh4gGMxXCn0dv7vL9ctxooWpSgWGbGfBA,7212
-nucliadb_utils/utilities.py,sha256=s5MVXDj4DTtc1VPFBRxMjud3HB0xkadyQ0f7QQLb0NM,14178
+nucliadb_utils/utilities.py,sha256=mbwU6BWoUNoTiWDeUqWK-VCtOM8ZFAyPz0SX7Q818PY,13724
 nucliadb_utils/audit/__init__.py,sha256=cp15ZcFnHvpcu_5-aK2A4uUyvuZVV_MJn4bIXMa20ks,835
 nucliadb_utils/audit/audit.py,sha256=fmEVb6ahKrkGAY-GEy4_L4ccmcGM5YKl-Vs05260_cg,2834
 nucliadb_utils/audit/basic.py,sha256=8yL7HI9MnykSt7j4QbUeRBbBsTKFIIX6hppJ3ADVLdM,3430
@@ -40,24 +40,22 @@ nucliadb_utils/nuclia_usage/utils/__init__.py,sha256=cp15ZcFnHvpcu_5-aK2A4uUyvuZ
 nucliadb_utils/nuclia_usage/utils/kb_usage_report.py,sha256=E1eUSFXBVNzQP9Q2rWj9y3koCO5S7iKwckny_AoLKuk,3870
 nucliadb_utils/storages/__init__.py,sha256=5Qc8AUWiJv9_JbGCBpAn88AIJhwDlm0OPQpg2ZdRL4U,872
 nucliadb_utils/storages/exceptions.py,sha256=n6aBOyurWMo8mXd1XY6Psgno4VfXJ9TRbxCy67c08-g,2417
-nucliadb_utils/storages/gcs.py,sha256=krBkNd7wkHhfIn3T-4QvYu1Rw-envYCa6G4G90oOjvM,27303
-nucliadb_utils/storages/local.py,sha256=JewYQ-fes9iUtUjlbHgWXrG1RsQWh16TJDunJnwfbTg,10447
+nucliadb_utils/storages/gcs.py,sha256=JcIL9gQ1YCXtNkuEhFciP_VcgyWcy4e4xuN01d2eZIg,27372
+nucliadb_utils/storages/local.py,sha256=nDrmWy1na96AS__hO3TQqsYMHnu0buwnfUGWfxCpWYU,10348
 nucliadb_utils/storages/nuclia.py,sha256=UfvRu92eqG1v-PE-UWH2x8KEJFqDqATMmUGFmEuqSSs,2097
-nucliadb_utils/storages/pg.py,sha256=DxXNwcstAFOTC6kaXlWp-b4WrvR8aSSOfgVJNDQ5oDI,18976
-nucliadb_utils/storages/s3.py,sha256=f2bjgmT6JRlUr5DHy3tRUip4kYSA1MzXfYrLNVUp_Cg,19447
+nucliadb_utils/storages/s3.py,sha256=RRbcYr4FE-Vfisr-zPoUN0Q_LfEHF-L2B0ggFuVsOwU,19500
 nucliadb_utils/storages/settings.py,sha256=ugCPy1zxBOmA2KosT-4tsjpvP002kg5iQyi42yCGCJA,1285
-nucliadb_utils/storages/storage.py,sha256=sR2Qvev6eLUvbH1WTXjqXIOnKRy1YMMx6Vsj0wZ2x8A,20585
+nucliadb_utils/storages/storage.py,sha256=KJ5VDYoZuRmiCFwfLj__tDOHIJWyQAMi-sOXCMoJv9w,20831
 nucliadb_utils/tests/__init__.py,sha256=Oo9CAE7B0eW5VHn8sHd6o30SQzOWUhktLPRXdlDOleA,1456
 nucliadb_utils/tests/asyncbenchmark.py,sha256=rN_NNDk4ras0qgFp0QlRyAi9ZU9xITdzxl2s5CigzBo,10698
-nucliadb_utils/tests/fixtures.py,sha256=ZvKaxZFMULC2Sbo0jSIuGxJW_cgVH_pNjhVYo9PbgyA,1665
+nucliadb_utils/tests/fixtures.py,sha256=j58fTvoWZClC52LX7QOvLXX9DS5QbytSnRp0F4nGzN8,1671
 nucliadb_utils/tests/gcs.py,sha256=1dbt_zG3uZPZDF3Nyrgrvi_bsKmafAUOm4Pu4bzt7wI,3098
 nucliadb_utils/tests/indexing.py,sha256=YW2QhkhO9Q_8A4kKWJaWSvXvyQ_AiAwY1VylcfVQFxk,1513
 nucliadb_utils/tests/local.py,sha256=c3gZJJWmvOftruJkIQIwB3q_hh3uxEhqGIAVWim1Bbk,1343
 nucliadb_utils/tests/nats.py,sha256=lgRe6YH9LSoI7XgcyKAC2VTSAtuu8EeMve0jWWC_kOY,7701
-nucliadb_utils/tests/pg.py,sha256=HBpvaNDs9T_L55tvJCJTPnsCDrB8ehI_9HRYz6SPWNE,1819
 nucliadb_utils/tests/s3.py,sha256=YB8QqDaBXxyhHonEHmeBbRRDmvB7sTOaKBSi8KBGokg,2330
-nucliadb_utils-4.0.3.post577.dist-info/METADATA,sha256=prh_IQhq0bc4fqgqeIUWgedwV_OyVzxFhH0_xSa1_Y0,2030
-nucliadb_utils-4.0.3.post577.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-nucliadb_utils-4.0.3.post577.dist-info/top_level.txt,sha256=fE3vJtALTfgh7bcAWcNhcfXkNPp_eVVpbKK-2IYua3E,15
-nucliadb_utils-4.0.3.post577.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
-nucliadb_utils-4.0.3.post577.dist-info/RECORD,,
+nucliadb_utils-4.0.3.post579.dist-info/METADATA,sha256=5OY05fts98E0YRW0yVwWXMwjWbcJwxlw8KU50Q2FVNI,2030
+nucliadb_utils-4.0.3.post579.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+nucliadb_utils-4.0.3.post579.dist-info/top_level.txt,sha256=fE3vJtALTfgh7bcAWcNhcfXkNPp_eVVpbKK-2IYua3E,15
+nucliadb_utils-4.0.3.post579.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
+nucliadb_utils-4.0.3.post579.dist-info/RECORD,,
nucliadb_utils/storages/pg.py (deleted)

@@ -1,617 +0,0 @@
-# Copyright (C) 2021 Bosutech XXI S.L.
-#
-# nucliadb is offered under the AGPL v3.0 and as commercial software.
-# For commercial licensing, contact us at info@nuclia.com.
-#
-# AGPL:
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU Affero General Public License for more details.
-#
-# You should have received a copy of the GNU Affero General Public License
-# along with this program. If not, see <http://www.gnu.org/licenses/>.
-#
-from __future__ import annotations
-
-import asyncio
-import logging
-import uuid
-from typing import Any, AsyncGenerator, AsyncIterator, Optional, TypedDict
-
-import asyncpg
-from nucliadb_protos.resources_pb2 import CloudFile
-
-from nucliadb_utils.storages import CHUNK_SIZE
-from nucliadb_utils.storages.storage import (
-    ObjectInfo,
-    ObjectMetadata,
-    Storage,
-    StorageField,
-)
-
-logger = logging.getLogger(__name__)
-
-# Table design notes
-# - No foreign key constraints ON PURPOSE
-# - No cascade handling ON PURPOSE
-CREATE_TABLE = """
-CREATE TABLE IF NOT EXISTS kb_files (
-    kb_id TEXT,
-    file_id TEXT,
-    filename TEXT,
-    size INTEGER,
-    content_type TEXT,
-    PRIMARY KEY(kb_id, file_id)
-);
-
-CREATE TABLE IF NOT EXISTS kb_files_fileparts (
-    kb_id TEXT,
-    file_id TEXT,
-    part_id INTEGER,
-    size INTEGER,
-    data BYTEA,
-    PRIMARY KEY(kb_id, file_id, part_id)
-);
-"""
-
-
-class FileInfo(TypedDict):
-    filename: str
-    size: int
-    content_type: str
-    key: str
-
-
-class ChunkInfo(TypedDict):
-    part_id: int
-    size: int
-
-
-class Chunk(ChunkInfo):
-    data: bytes
-
-
-class PostgresFileDataLayer:
-    """
-    Responsible for interating with the database and
-    abstracting any sql and connection management.
-    """
-
-    def __init__(self, connection: asyncpg.Connection):
-        self.connection = connection
-
-    async def initialize_kb(self, kbid: str) -> bool:
-        # there's really no record keeping or init
-        # per kb that we care to do
-        return True
-
-    async def delete_kb(self, kbid: str) -> bool:
-        async with self.connection.transaction():
-            await self.connection.execute(
-                """
-                DELETE FROM kb_files
-                WHERE kb_id = $1
-                """,
-                kbid,
-            )
-            await self.connection.execute(
-                """
-                DELETE FROM kb_files_fileparts
-                WHERE kb_id = $1
-                """,
-                kbid,
-            )
-        return True
-
-    async def create_file(
-        self, *, kb_id: str, file_id: str, filename: str, size: int, content_type: str
-    ) -> None:
-        async with self.connection.transaction():
-            await self.connection.execute(
-                """
-                INSERT INTO kb_files (kb_id, file_id, filename, size, content_type)
-                VALUES ($1, $2, $3, $4, $5)
-                """,
-                kb_id,
-                file_id,
-                filename or "",
-                size,
-                content_type or "",
-            )
-
-    async def delete_file(self, kb_id: str, file_id: str) -> None:
-        async with self.connection.transaction():
-            await self.connection.execute(
-                """
-                DELETE FROM kb_files
-                WHERE kb_id = $1 AND file_id = $2
-                """,
-                kb_id,
-                file_id,
-            )
-            await self.connection.execute(
-                """
-                DELETE FROM kb_files_fileparts
-                WHERE kb_id = $1 AND file_id = $2
-                """,
-                kb_id,
-                file_id,
-            )
-
-    async def append_chunk(self, *, kb_id: str, file_id: str, data: bytes) -> None:
-        async with self.connection.transaction():
-            await self.connection.execute(
-                """
-                INSERT INTO kb_files_fileparts (kb_id, file_id, part_id, data, size)
-                VALUES (
-                    $1, $2,
-                    (
-                        SELECT COALESCE(MAX(part_id), 0) + 1
-                        FROM kb_files_fileparts WHERE kb_id = $1 AND file_id = $2
-                    ),
-                    $3, $4)
-                """,
-                kb_id,
-                file_id,
-                data,
-                len(data),
-            )
-
-    async def get_file_info(self, kb_id: str, file_id: str) -> Optional[FileInfo]:
-        record = await self.connection.fetchrow(
-            """
-            SELECT filename, size, content_type, file_id
-            FROM kb_files
-            WHERE kb_id = $1 AND file_id = $2
-            """,
-            kb_id,
-            file_id,
-        )
-        if record is None:
-            return None
-        return FileInfo(
-            filename=record["filename"],
-            size=record["size"],
-            content_type=record["content_type"],
-            key=record["file_id"],
-        )
-
-    async def move(
-        self,
-        *,
-        origin_key: str,
-        destination_key: str,
-        origin_kb: str,
-        destination_kb: str,
-    ):
-        async with self.connection.transaction():
-            # make sure to delete the destination first in
-            # case this is an overwrite of an existing
-            await self.connection.execute(
-                """
-                delete from kb_files
-                WHERE kb_id = $1 AND file_id = $2
-                """,
-                destination_kb,
-                destination_key,
-            )
-            await self.connection.execute(
-                """
-                UPDATE kb_files
-                SET kb_id = $1, file_id = $2
-                WHERE kb_id = $3 AND file_id = $4
-                """,
-                destination_kb,
-                destination_key,
-                origin_kb,
-                origin_key,
-            )
-            # make sure to delete the destination first in
-            # case this is an overwrite of an existing
-            await self.connection.execute(
-                """
-                delete from kb_files_fileparts
-                WHERE kb_id = $1 AND file_id = $2
-                """,
-                destination_kb,
-                destination_key,
-            )
-            await self.connection.execute(
-                """
-                UPDATE kb_files_fileparts
-                SET kb_id = $1, file_id = $2
-                WHERE kb_id = $3 AND file_id = $4
-                """,
-                destination_kb,
-                destination_key,
-                origin_kb,
-                origin_key,
-            )
-
-    async def copy(
-        self,
-        *,
-        origin_key: str,
-        destination_key: str,
-        origin_kb: str,
-        destination_kb: str,
-    ):
-        async with self.connection.transaction():
-            await self.connection.execute(
-                """
-                INSERT INTO kb_files (kb_id, file_id, filename, size, content_type)
-                SELECT $1, $2, filename, size, content_type
-                FROM kb_files
-                WHERE kb_id = $3 AND file_id = $4
-                """,
-                destination_kb,
-                destination_key,
-                origin_kb,
-                origin_key,
-            )
-
-            await self.connection.execute(
-                """
-                INSERT INTO kb_files_fileparts (kb_id, file_id, part_id, data, size)
-                SELECT $1, $2, part_id, data, size
-                FROM kb_files_fileparts
-                WHERE kb_id = $3 AND file_id = $4
-                """,
-                destination_kb,
-                destination_key,
-                origin_kb,
-                origin_key,
-            )
-
-    async def get_chunks_info(
-        self, bucket: str, key: str, part_ids: Optional[list[int]] = None
-    ) -> list[ChunkInfo]:
-        query = """
-        select kb_id, file_id, part_id, size
-        from kb_files_fileparts
-        where kb_id = $1 and file_id = $2
-        """
-        args: list[Any] = [bucket, key]
-        if part_ids is not None:
-            query += " and part_id = ANY($3)"
-            args.append(part_ids)
-        query += " order by part_id"
-        chunks = await self.connection.fetch(query, *args)
-        return [
-            ChunkInfo(
-                part_id=chunk["part_id"],
-                size=chunk["size"],
-            )
-            for chunk in chunks
-        ]
-
-    async def iterate_kb(
-        self, bucket: str, prefix: Optional[str] = None
-    ) -> AsyncGenerator[FileInfo, None]:
-        query = """
-        SELECT filename, size, content_type, file_id
-        FROM kb_files
-        WHERE kb_id = $1
-        """
-        args: list[Any] = [bucket]
-        if prefix:
-            query += " AND filename LIKE $2"
-            args.append(prefix + "%")
-        async with self.connection.transaction():
-            async for record in self.connection.cursor(query, *args):
-                yield FileInfo(
-                    filename=record["filename"],
-                    size=record["size"],
-                    content_type=record["content_type"],
-                    key=record["file_id"],
-                )
-
-    async def iterate_chunks(
-        self, bucket: str, key: str, part_ids: Optional[list[int]] = None
-    ) -> AsyncIterator[Chunk]:
-        chunks = await self.get_chunks_info(bucket, key, part_ids=part_ids)
-        for chunk in chunks:
-            # who knows how long a download for one of these chunks could be,
-            # so let's not try to keep a txn or cursor open.
-            data_chunk = await self.connection.fetchrow(
-                """
-                select data
-                from kb_files_fileparts
-                where kb_id = $1 and file_id = $2 and part_id = $3
-                """,
-                bucket,
-                key,
-                chunk["part_id"],
-            )
-            yield Chunk(
-                part_id=chunk["part_id"],
-                size=chunk["size"],
-                data=data_chunk["data"],
-            )
-
-    async def iterate_range(
-        self, *, kb_id: str, file_id: str, start: int, end: int
-    ) -> AsyncIterator[bytes]:
-        chunks = await self.get_chunks_info(
-            kb_id,
-            file_id,
-        )
-
-        # First off, find start part and position
-        elapsed = 0
-        start_part_id = None
-        start_pos = -1
-        for chunk in chunks:
-            if elapsed + chunk["size"] > start:
-                start_part_id = chunk["part_id"]
-                start_pos = start - elapsed
-                break
-            else:
-                elapsed += chunk["size"]
-
-        if start_part_id is None:
-            return
-
-        # Now, iterate through the chunks and yield the data
-        read_bytes = 0
-        while read_bytes < end - start:
-            data_chunk = await self.connection.fetchrow(
-                """
-                select data
-                from kb_files_fileparts
-                where kb_id = $1 and file_id = $2 and part_id = $3
-                """,
-                kb_id,
-                file_id,
-                start_part_id,
-            )
-            if data_chunk is None:
-                return
-
-            data = data_chunk["data"][
-                start_pos : min(
-                    start_pos + ((end - start) - read_bytes), len(data_chunk["data"])
-                )
-            ]
-            read_bytes += len(data)
-            yield data
-            start_pos = 0
-            start_part_id += 1
-
-
-class PostgresStorageField(StorageField):
-    storage: PostgresStorage
-
-    async def move(
-        self,
-        origin_uri: str,
-        destination_uri: str,
-        origin_bucket_name: str,
-        destination_bucket_name: str,
-    ):
-        async with self.storage.pool.acquire() as conn:
-            dl = PostgresFileDataLayer(conn)
-            return await dl.move(
-                origin_key=origin_uri,
-                destination_key=destination_uri,
-                origin_kb=origin_bucket_name,
-                destination_kb=destination_bucket_name,
-            )
-
-    async def copy(
-        self,
-        origin_uri: str,
-        destination_uri: str,
-        origin_bucket_name: str,
-        destination_bucket_name: str,
-    ):
-        async with self.storage.pool.acquire() as conn:
-            dl = PostgresFileDataLayer(conn)
-            return await dl.copy(
-                origin_key=origin_uri,
-                destination_key=destination_uri,
-                origin_kb=origin_bucket_name,
-                destination_kb=destination_bucket_name,
-            )
-
-    async def iter_data(self, headers=None):
-        key = self.field.uri if self.field else self.key
-        if self.field is None:
-            bucket = self.bucket
-        else:
-            bucket = self.field.bucket_name
-
-        async with self.storage.pool.acquire() as conn:
-            dl = PostgresFileDataLayer(conn)
-            async for chunk in dl.iterate_chunks(bucket, key):
-                yield chunk["data"]
-
-    async def read_range(self, start: int, end: int) -> AsyncGenerator[bytes, None]:
-        """
-        Iterate through ranges of data
-        """
-        key = self.field.uri if self.field else self.key
-        if self.field is None:
-            bucket = self.bucket
-        else:
-            bucket = self.field.bucket_name
-
-        async with self.storage.pool.acquire() as conn:
-            dl = PostgresFileDataLayer(conn)
-            async for data in dl.iterate_range(
-                kb_id=bucket, file_id=key, start=start, end=end
-            ):
-                yield data
-
-    async def start(self, cf: CloudFile) -> CloudFile:
-        field = CloudFile(
-            filename=cf.filename,
-            size=cf.size,
-            md5=cf.md5,
-            content_type=cf.content_type,
-            bucket_name=self.bucket,
-            source=CloudFile.POSTGRES,
-        )
-        upload_uri = uuid.uuid4().hex
-
-        async with self.storage.pool.acquire() as conn:
-            async with conn.transaction():
-                dl = PostgresFileDataLayer(conn)
-
-                if self.field is not None and self.field.upload_uri != "":
-                    # If there is a temporal url
-                    await dl.delete_file(self.field.bucket_name, self.field.upload_uri)
-
-                await dl.create_file(
-                    kb_id=self.bucket,
-                    file_id=upload_uri,
-                    filename=cf.filename,
-                    size=cf.size,
-                    content_type=cf.content_type,
-                )
-
-        field.offset = 0
-        field.upload_uri = upload_uri
-        return field
-
-    async def append(self, cf: CloudFile, iterable: AsyncIterator) -> int:
-        if self.field is None:
-            raise AttributeError()
-        count = 0
-        async with self.storage.pool.acquire() as conn:
-            dl = PostgresFileDataLayer(conn)
-            async for chunk in iterable:
-                await dl.append_chunk(
-                    kb_id=self.bucket,
-                    file_id=cf.upload_uri or self.field.upload_uri,
-                    data=chunk,
-                )
-                size = len(chunk)
-                count += size
-                self.field.offset += len(chunk)
-        return count
-
-    async def finish(self):
-        async with self.storage.pool.acquire() as conn, conn.transaction():
-            dl = PostgresFileDataLayer(conn)
-            if self.field.old_uri not in ("", None):
-                # Already has a file
-                await dl.delete_file(self.bucket, self.field.uri)
-
-            if self.field.upload_uri != self.key:
-                try:
-                    await dl.move(
-                        origin_key=self.field.upload_uri,
-                        destination_key=self.key,
-                        origin_kb=self.field.bucket_name,
-                        destination_kb=self.bucket,
-                    )
-                except Exception:
-                    logger.exception(
-                        f"Error moving file {self.field.bucket_name}://{self.field.upload_uri} -> {self.bucket}://{self.key}"  # noqa
-                    )
-                    raise
-
-        self.field.uri = self.key
-        self.field.ClearField("offset")
-        self.field.ClearField("upload_uri")
-
-    async def exists(self) -> Optional[ObjectMetadata]:
-        async with self.storage.pool.acquire() as conn:
-            dl = PostgresFileDataLayer(conn)
-            file_info = await dl.get_file_info(self.bucket, self.key)
-            if file_info is None:
-                return None
-            return ObjectMetadata(
-                filename=file_info["filename"],
-                size=file_info["size"],
-                content_type=file_info["content_type"],
-            )
-
-    async def upload(self, iterator: AsyncIterator, origin: CloudFile) -> CloudFile:
-        self.field = await self.start(origin)
-        await self.append(origin, iterator)
-        await self.finish()
-        return self.field
-
-    def __repr__(self):
-        return f"{self.storage.source}: {self.bucket}/{self.key}"
-
-
-class PostgresStorage(Storage):
-    field_klass = PostgresStorageField
-    chunk_size = CHUNK_SIZE
-    pool: asyncpg.pool.Pool
-
-    def __init__(self, dsn: str, connection_pool_max_size: int = 10):
-        self.dsn = dsn
-        self.connection_pool_max_size = connection_pool_max_size
-        self.source = CloudFile.POSTGRES
-        self._lock = asyncio.Lock()
-        self.initialized = False
-
-    async def initialize(self):
-        async with self._lock:
-            if self.initialized is False:
-                self.pool = await asyncpg.create_pool(
-                    self.dsn,
-                    max_size=self.connection_pool_max_size,
-                )
-
-                # check if table exists
-                try:
-                    async with self.pool.acquire() as conn:
-                        await conn.execute(CREATE_TABLE)
-                except asyncpg.exceptions.UniqueViolationError:  # pragma: no cover
-                    pass
-
-                self.initialized = True
-
-    async def finalize(self):
-        async with self._lock:
-            await self.pool.close()
-            self.initialized = False
-
-    def get_bucket_name(self, kbid: str):
-        return kbid
-
-    async def create_kb(self, kbid: str):
-        async with self.pool.acquire() as conn:
-            dl = PostgresFileDataLayer(conn)
-            return await dl.initialize_kb(kbid)
-
-    async def delete_kb(self, kbid: str) -> tuple[bool, bool]:
-        async with self.pool.acquire() as conn:
-            dl = PostgresFileDataLayer(conn)
-            return await dl.delete_kb(kbid), False
-
-    async def delete_upload(self, uri: str, bucket_name: str):
-        async with self.pool.acquire() as conn:
-            dl = PostgresFileDataLayer(conn)
-            await dl.delete_file(bucket_name, uri)
-
-    async def schedule_delete_kb(self, kbid: str) -> bool:
-        await self.delete_kb(kbid)
-        return True
-
-    async def iterate_objects(
-        self, bucket: str, prefix: str
-    ) -> AsyncGenerator[ObjectInfo, None]:
-        async with self.pool.acquire() as conn:
-            dl = PostgresFileDataLayer(conn)
-            async for file_data in dl.iterate_kb(bucket, prefix):
-                yield ObjectInfo(name=file_data["key"])
-
-    async def download(
-        self, bucket_name: str, key: str, headers: Optional[dict[str, str]] = None
-    ) -> AsyncIterator[bytes]:
-        async with self.pool.acquire() as conn:
-            dl = PostgresFileDataLayer(conn)
-            async for chunk in dl.iterate_chunks(bucket_name, key):
-                yield chunk["data"]
nucliadb_utils/tests/pg.py (deleted)

@@ -1,57 +0,0 @@
-# Copyright (C) 2021 Bosutech XXI S.L.
-#
-# nucliadb is offered under the AGPL v3.0 and as commercial software.
-# For commercial licensing, contact us at info@nuclia.com.
-#
-# AGPL:
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as
-# published by the Free Software Foundation, either version 3 of the
-# License, or (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU Affero General Public License for more details.
-#
-# You should have received a copy of the GNU Affero General Public License
-# along with this program. If not, see <http://www.gnu.org/licenses/>.
-#
-import asyncpg
-import pytest
-from pytest_docker_fixtures import images  # type: ignore
-
-from nucliadb_utils.storages.pg import PostgresStorage
-from nucliadb_utils.store import MAIN
-from nucliadb_utils.utilities import Utility
-
-images.settings["postgresql"].update(
-    {
-        "version": "16.1",
-        "env": {
-            "POSTGRES_PASSWORD": "postgres",
-            "POSTGRES_DB": "postgres",
-            "POSTGRES_USER": "postgres",
-        },
-    }
-)
-
-
-@pytest.fixture(scope="function")
-async def pg_storage(pg):
-    dsn = f"postgresql://postgres:postgres@{pg[0]}:{pg[1]}/postgres"
-    storage = PostgresStorage(dsn)
-    MAIN[Utility.STORAGE] = storage
-    conn = await asyncpg.connect(dsn)
-    await conn.execute(
-        """
-        DROP table IF EXISTS kb_files;
-        DROP table IF EXISTS kb_files_fileparts;
-        """
-    )
-    await conn.close()
-    await storage.initialize()
-    yield storage
-    await storage.finalize()
-    if Utility.STORAGE in MAIN:
-        del MAIN[Utility.STORAGE]