python3-commons 0.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of python3-commons might be problematic. Click here for more details.
- python3_commons/__init__.py +7 -0
- python3_commons/api_client.py +120 -0
- python3_commons/audit.py +196 -0
- python3_commons/auth.py +82 -0
- python3_commons/cache.py +261 -0
- python3_commons/conf.py +52 -0
- python3_commons/db/__init__.py +83 -0
- python3_commons/db/helpers.py +62 -0
- python3_commons/db/models/__init__.py +2 -0
- python3_commons/db/models/auth.py +35 -0
- python3_commons/db/models/common.py +39 -0
- python3_commons/db/models/rbac.py +91 -0
- python3_commons/fs.py +10 -0
- python3_commons/helpers.py +108 -0
- python3_commons/logging/__init__.py +0 -0
- python3_commons/logging/filters.py +10 -0
- python3_commons/logging/formatters.py +25 -0
- python3_commons/object_storage.py +127 -0
- python3_commons/permissions.py +48 -0
- python3_commons/serializers/__init__.py +0 -0
- python3_commons/serializers/json.py +26 -0
- python3_commons/serializers/msgpack.py +50 -0
- python3_commons/serializers/msgspec.py +73 -0
- python3_commons-0.0.0.dist-info/METADATA +34 -0
- python3_commons-0.0.0.dist-info/RECORD +29 -0
- python3_commons-0.0.0.dist-info/WHEEL +5 -0
- python3_commons-0.0.0.dist-info/licenses/AUTHORS.rst +5 -0
- python3_commons-0.0.0.dist-info/licenses/LICENSE +604 -0
- python3_commons-0.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from contextlib import asynccontextmanager
|
|
3
|
+
from datetime import UTC, datetime
|
|
4
|
+
from enum import Enum
|
|
5
|
+
from http import HTTPStatus
|
|
6
|
+
from json import dumps
|
|
7
|
+
from typing import AsyncGenerator, Literal, Mapping, Sequence
|
|
8
|
+
from uuid import uuid4
|
|
9
|
+
|
|
10
|
+
from aiohttp import ClientResponse, ClientSession, ClientTimeout, client_exceptions
|
|
11
|
+
from pydantic import HttpUrl
|
|
12
|
+
|
|
13
|
+
from python3_commons import audit
|
|
14
|
+
from python3_commons.conf import s3_settings
|
|
15
|
+
from python3_commons.helpers import request_to_curl
|
|
16
|
+
from python3_commons.serializers.json import CustomJSONEncoder
|
|
17
|
+
|
|
18
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
async def _store_response_for_audit(
    response: ClientResponse, audit_name: str, uri_path: str, method: str, request_id: str
):
    """Write the text body of *response* to the audit storage.

    The object key is ``<Y/m/d>/<audit_name>/<uri_path>/<method>_<HHMMSS_micro>_<request_id>_response.txt``.
    The key uses the time at which the response is stored. Empty bodies are not stored.
    """
    body = await response.text()
    if not body:
        return

    stored_at = datetime.now(tz=UTC)
    day_prefix = stored_at.strftime('%Y/%m/%d')
    clock = stored_at.strftime('%H%M%S_%f')
    key = f'{day_prefix}/{audit_name}/{uri_path}/{method}_{clock}_{request_id}_response.txt'

    await audit.write_audit_data(s3_settings, key, body.encode('utf-8'))
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@asynccontextmanager
|
|
39
|
+
async def request(
|
|
40
|
+
client: ClientSession,
|
|
41
|
+
base_url: HttpUrl,
|
|
42
|
+
uri: str,
|
|
43
|
+
query: Mapping | None = None,
|
|
44
|
+
method: Literal['get', 'post', 'put', 'patch', 'options', 'head', 'delete'] = 'get',
|
|
45
|
+
headers: Mapping | None = None,
|
|
46
|
+
json: Mapping | Sequence | str | None = None,
|
|
47
|
+
data: bytes | None = None,
|
|
48
|
+
timeout: ClientTimeout | Enum | None = None,
|
|
49
|
+
audit_name: str | None = None,
|
|
50
|
+
) -> AsyncGenerator[ClientResponse]:
|
|
51
|
+
now = datetime.now(tz=UTC)
|
|
52
|
+
date_path = now.strftime('%Y/%m/%d')
|
|
53
|
+
timestamp = now.strftime('%H%M%S_%f')
|
|
54
|
+
request_id = str(uuid4())[-12:]
|
|
55
|
+
uri_path = uri[:-1] if uri.endswith('/') else uri
|
|
56
|
+
uri_path = uri_path[1:] if uri_path.startswith('/') else uri_path
|
|
57
|
+
url = f'{u[:-1] if (u := str(base_url)).endswith("/") else u}{uri}'
|
|
58
|
+
|
|
59
|
+
if audit_name:
|
|
60
|
+
curl_request = None
|
|
61
|
+
|
|
62
|
+
if method == 'get':
|
|
63
|
+
if headers or query:
|
|
64
|
+
curl_request = request_to_curl(url, query, method, headers)
|
|
65
|
+
else:
|
|
66
|
+
curl_request = request_to_curl(url, query, method, headers, json, data)
|
|
67
|
+
|
|
68
|
+
if curl_request:
|
|
69
|
+
await audit.write_audit_data(
|
|
70
|
+
s3_settings,
|
|
71
|
+
f'{date_path}/{audit_name}/{uri_path}/{method}_{timestamp}_{request_id}_request.txt',
|
|
72
|
+
curl_request.encode('utf-8'),
|
|
73
|
+
)
|
|
74
|
+
client_method = getattr(client, method)
|
|
75
|
+
|
|
76
|
+
logger.debug(f'Requesting {method} {url}')
|
|
77
|
+
|
|
78
|
+
try:
|
|
79
|
+
if method == 'get':
|
|
80
|
+
async with client_method(url, params=query, headers=headers, timeout=timeout) as response:
|
|
81
|
+
if audit_name:
|
|
82
|
+
await _store_response_for_audit(response, audit_name, uri_path, method, request_id)
|
|
83
|
+
|
|
84
|
+
if response.ok:
|
|
85
|
+
yield response
|
|
86
|
+
else:
|
|
87
|
+
match response.status:
|
|
88
|
+
case HTTPStatus.UNAUTHORIZED:
|
|
89
|
+
raise PermissionError('Unauthorized')
|
|
90
|
+
case HTTPStatus.FORBIDDEN:
|
|
91
|
+
raise PermissionError('Forbidden')
|
|
92
|
+
case HTTPStatus.NOT_FOUND:
|
|
93
|
+
raise LookupError('Not found')
|
|
94
|
+
case HTTPStatus.BAD_REQUEST:
|
|
95
|
+
raise ValueError('Bad request')
|
|
96
|
+
case _:
|
|
97
|
+
response.raise_for_status()
|
|
98
|
+
else:
|
|
99
|
+
if json:
|
|
100
|
+
data = dumps(json, cls=CustomJSONEncoder).encode('utf-8')
|
|
101
|
+
|
|
102
|
+
if headers:
|
|
103
|
+
headers = {**headers, 'Content-Type': 'application/json'}
|
|
104
|
+
else:
|
|
105
|
+
headers = {'Content-Type': 'application/json'}
|
|
106
|
+
|
|
107
|
+
async with client_method(url, params=query, data=data, headers=headers, timeout=timeout) as response:
|
|
108
|
+
if audit_name:
|
|
109
|
+
await _store_response_for_audit(response, audit_name, uri_path, method, request_id)
|
|
110
|
+
|
|
111
|
+
yield response
|
|
112
|
+
except client_exceptions.ClientOSError as e:
|
|
113
|
+
if e.errno == 32:
|
|
114
|
+
raise ConnectionResetError('Broken pipe') from e
|
|
115
|
+
elif e.errno == 104:
|
|
116
|
+
raise ConnectionResetError('Connection reset by peer') from e
|
|
117
|
+
|
|
118
|
+
raise
|
|
119
|
+
except client_exceptions.ServerDisconnectedError as e:
|
|
120
|
+
raise ConnectionResetError('Server disconnected') from e
|
python3_commons/audit.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import io
|
|
3
|
+
import logging
|
|
4
|
+
import tarfile
|
|
5
|
+
from bz2 import BZ2Compressor
|
|
6
|
+
from collections import deque
|
|
7
|
+
from datetime import UTC, datetime, timedelta
|
|
8
|
+
from typing import Generator, Iterable
|
|
9
|
+
from uuid import uuid4
|
|
10
|
+
|
|
11
|
+
from lxml import etree
|
|
12
|
+
from minio import S3Error
|
|
13
|
+
from zeep.plugins import Plugin
|
|
14
|
+
from zeep.wsdl.definitions import AbstractOperation
|
|
15
|
+
|
|
16
|
+
from python3_commons import object_storage
|
|
17
|
+
from python3_commons.conf import S3Settings, s3_settings
|
|
18
|
+
from python3_commons.object_storage import ObjectStorage
|
|
19
|
+
|
|
20
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class GeneratedStream(io.BytesIO):
|
|
24
|
+
def __init__(self, generator: Generator[bytes, None, None], *args, **kwargs):
|
|
25
|
+
super().__init__(*args, **kwargs)
|
|
26
|
+
self.generator = generator
|
|
27
|
+
|
|
28
|
+
def read(self, size: int = -1):
|
|
29
|
+
if size < 0:
|
|
30
|
+
while True:
|
|
31
|
+
try:
|
|
32
|
+
chunk = next(self.generator)
|
|
33
|
+
except StopIteration:
|
|
34
|
+
break
|
|
35
|
+
else:
|
|
36
|
+
self.write(chunk)
|
|
37
|
+
else:
|
|
38
|
+
total_written_size = 0
|
|
39
|
+
|
|
40
|
+
while total_written_size < size:
|
|
41
|
+
try:
|
|
42
|
+
chunk = next(self.generator)
|
|
43
|
+
except StopIteration:
|
|
44
|
+
break
|
|
45
|
+
else:
|
|
46
|
+
total_written_size += self.write(chunk)
|
|
47
|
+
|
|
48
|
+
self.seek(0)
|
|
49
|
+
|
|
50
|
+
if chunk := super().read(size):
|
|
51
|
+
pos = self.tell()
|
|
52
|
+
|
|
53
|
+
buf = self.getbuffer()
|
|
54
|
+
unread_data_size = len(buf) - pos
|
|
55
|
+
|
|
56
|
+
if unread_data_size > 0:
|
|
57
|
+
buf[:unread_data_size] = buf[pos : pos + unread_data_size]
|
|
58
|
+
|
|
59
|
+
del buf
|
|
60
|
+
|
|
61
|
+
self.seek(0)
|
|
62
|
+
self.truncate(unread_data_size)
|
|
63
|
+
|
|
64
|
+
return chunk
|
|
65
|
+
|
|
66
|
+
def readable(self):
|
|
67
|
+
return True
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _drain_chunks(buffer: io.BytesIO, chunk_size: int, final: bool = False) -> Generator[bytes, None, None]:
    """Yield the complete ``chunk_size`` pieces held in *buffer* and keep the rest buffered.

    When *final* is true, the trailing partial piece is yielded as well. The
    buffer is left holding only the bytes that were not yielded, with its
    position at the end, so later tarfile writes append after them.
    """
    data = buffer.getvalue()
    end = len(data) if final else len(data) - len(data) % chunk_size

    buffer.seek(0)
    buffer.truncate()
    buffer.write(data[end:])

    for start in range(0, end, chunk_size):
        yield data[start : start + chunk_size]


def generate_archive(
    objects: Iterable[tuple[str, datetime, bytes]], chunk_size: int = 4096
) -> Generator[bytes, None, None]:
    """Stream an uncompressed tar archive built from *objects*.

    Args:
        objects: ``(name, last_modified, content)`` triples, one per archive member.
        chunk_size: Size of every yielded chunk except possibly the last one.

    Yields:
        Consecutive pieces of the tar archive, including the end-of-archive blocks.

    Raises:
        ValueError: If *chunk_size* is not positive (raised on first iteration).
    """
    if chunk_size <= 0:
        raise ValueError('chunk_size must be positive')

    buffer = io.BytesIO()

    with tarfile.open(fileobj=buffer, mode='w') as archive:
        for name, last_modified, content in objects:
            logger.info(f'Adding {name} to archive')
            info = tarfile.TarInfo(name)
            info.size = len(content)
            info.mtime = last_modified.timestamp()
            archive.addfile(info, io.BytesIO(content))

            # Hand out full chunks as soon as they exist, so memory holds at most
            # roughly one member plus one chunk.
            yield from _drain_chunks(buffer, chunk_size)

    # Closing the archive appended the end-of-archive blocks; flush everything left.
    yield from _drain_chunks(buffer, chunk_size, final=True)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def generate_bzip2(chunks: Generator[bytes, None, None]) -> Generator[bytes, None, None]:
    """Compress the byte stream *chunks* with bzip2, yielding output as the compressor emits it.

    Empty compressor outputs are skipped. The final flush is yielded last, so the
    concatenation of all yielded pieces is one complete bzip2 stream.
    """
    compressor = BZ2Compressor()

    for raw in chunks:
        compressed = compressor.compress(raw)
        if compressed:
            yield compressed

    tail = compressor.flush()
    if tail:
        yield tail
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def write_audit_data_sync(settings: S3Settings, key: str, data: bytes):
    """Upload *data* (blocking) to ``settings.s3_bucket`` at the absolute path of ``audit/<key>``.

    When no S3 secret key is configured, nothing is uploaded and a debug message
    is logged. ``S3Error`` from the upload is logged and swallowed; other
    exceptions propagate.
    """
    if not settings.s3_secret_access_key:
        logger.debug(f'S3 is not configured, not storing object in storage: {key}')
        return

    try:
        storage_client = ObjectStorage(settings).get_client()
        object_path = object_storage.get_absolute_path(f'audit/{key}')
        payload = io.BytesIO(data)
        storage_client.put_object(settings.s3_bucket, object_path, payload, len(data))
    except S3Error as e:
        logger.error(f'Failed storing object in storage: {e}')
    else:
        logger.debug(f'Stored object in storage: {key}')
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
async def write_audit_data(settings: S3Settings, key: str, data: bytes):
    """Run :func:`write_audit_data_sync` in a worker thread so the event loop is not blocked."""
    blocking_upload = write_audit_data_sync
    await asyncio.to_thread(blocking_upload, settings, key, data)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def _archive_audit_data_sync(root_path: str) -> None:
    """Bundle yesterday's (UTC) objects under *root_path* into one ``.tar.bz2`` and delete the originals.

    Performs blocking object-storage calls; run it off the event loop.
    """
    yesterday = datetime.now(tz=UTC) - timedelta(days=1)
    year, month, day = yesterday.year, yesterday.month, yesterday.day
    bucket_name = s3_settings.s3_bucket
    date_path = object_storage.get_absolute_path(f'{root_path}/{year}/{month:02}/{day:02}')

    if objects := object_storage.get_objects(bucket_name, date_path, recursive=True):
        logger.info(f'Compacting files in: {date_path}')

        # The chain tar -> bzip2 -> file-like stream lets the upload pull archive
        # bytes on demand.
        generator = generate_archive(objects, chunk_size=900_000)
        bzip2_generator = generate_bzip2(generator)
        archive_stream = GeneratedStream(bzip2_generator)

        # NOTE(review): the archive location is always under 'audit/.archive' even
        # when root_path differs; confirm this is intended.
        archive_path = object_storage.get_absolute_path(f'audit/.archive/{year}_{month:02}_{day:02}.tar.bz2')
        # The length is unknown up front (-1), so an explicit part size is passed.
        object_storage.put_object(bucket_name, archive_path, archive_stream, -1, part_size=5 * 1024 * 1024)

        if errors := object_storage.remove_objects(bucket_name, date_path):
            for error in errors:
                logger.error(f'Failed to delete object in {bucket_name=}: {error}')


async def archive_audit_data(root_path: str = 'audit'):
    """Archive yesterday's audit objects under *root_path* without blocking the event loop.

    The storage calls are blocking, so the work runs in a worker thread, just as
    :func:`write_audit_data` does for single uploads.
    """
    await asyncio.to_thread(_archive_audit_data_sync, root_path)
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class ZeepAuditPlugin(Plugin):
    """Zeep plugin that stores every outgoing and incoming SOAP envelope in the audit storage."""

    def __init__(self, audit_name: str = 'zeep'):
        super().__init__()
        self.audit_name = audit_name
        # Strong references to in-flight upload tasks. The event loop keeps only
        # weak references, so an unreferenced task could be garbage-collected
        # before it finishes.
        self._background_tasks: set[asyncio.Task] = set()

    def store_audit_in_s3(self, envelope, operation: AbstractOperation, direction: str):
        """Serialize *envelope* as pretty-printed UTF-8 XML and store it under a per-operation key.

        Inside a running event loop, the upload is scheduled as a background task.
        Otherwise it runs to completion with ``asyncio.run``.
        """
        xml = etree.tostring(envelope, encoding='UTF-8', pretty_print=True)
        now = datetime.now(tz=UTC)
        date_path = now.strftime('%Y/%m/%d')
        timestamp = now.strftime('%H%M%S')
        path = f'{date_path}/{self.audit_name}/{operation.name}/{timestamp}_{str(uuid4())[-12:]}_{direction}.xml'
        coro = write_audit_data(s3_settings, path, xml)

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No loop in this thread (synchronous zeep client): upload synchronously.
            asyncio.run(coro)
            return

        task = loop.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._on_audit_task_done)

    def _on_audit_task_done(self, task: asyncio.Task) -> None:
        """Release the finished task and log its failure, if any."""
        self._background_tasks.discard(task)

        if not task.cancelled() and (exc := task.exception()) is not None:
            logger.error(f'Failed storing zeep audit data: {exc}')

    def ingress(self, envelope, http_headers, operation: AbstractOperation):
        self.store_audit_in_s3(envelope, operation, 'ingress')

        return envelope, http_headers

    def egress(self, envelope, http_headers, operation: AbstractOperation, binding_options):
        self.store_audit_in_s3(envelope, operation, 'egress')

        return envelope, http_headers
|
python3_commons/auth.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from http import HTTPStatus
|
|
3
|
+
from typing import Annotated
|
|
4
|
+
|
|
5
|
+
import aiohttp
|
|
6
|
+
from fastapi import Depends, HTTPException
|
|
7
|
+
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
8
|
+
from jose import JWTError, jwt
|
|
9
|
+
from pydantic import BaseModel
|
|
10
|
+
|
|
11
|
+
from python3_commons.conf import oidc_settings
|
|
12
|
+
|
|
13
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TokenData(BaseModel):
    """Registered JWT claims (RFC 7519) kept from a verified access token."""

    sub: str  # subject
    # RFC 7519 allows "aud" to be a single string or an array of strings;
    # identity providers commonly issue the array form.
    aud: str | list[str]
    exp: int  # expiration time (NumericDate)
    iss: str  # issuer
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# OIDC discovery document URL. Trailing slashes are stripped because str() of a
# pydantic URL may end with '/', which would otherwise produce '//.well-known'.
OIDC_CONFIG_URL = f'{str(oidc_settings.authority_url).rstrip("/")}/.well-known/openid-configuration'
# JWKS downloaded on the first token verification and reused afterwards.
_JWKS: dict | None = None

# Reject requests without a bearer token only when OIDC is enabled.
bearer_security = HTTPBearer(auto_error=oidc_settings.enabled)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
async def fetch_openid_config() -> dict:
    """Download the OIDC discovery document (which includes ``jwks_uri``) from the authority.

    Raises:
        HTTPException: 500 when the authority does not answer with 200 OK.
    """
    async with aiohttp.ClientSession() as session, session.get(OIDC_CONFIG_URL) as response:
        if response.status == HTTPStatus.OK:
            return await response.json()

        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail='Failed to fetch OpenID configuration'
        )
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
async def fetch_jwks(jwks_uri: str) -> dict:
    """Download the JSON Web Key Set used to verify token signatures.

    Raises:
        HTTPException: 500 when *jwks_uri* does not answer with 200 OK.
    """
    async with aiohttp.ClientSession() as session, session.get(jwks_uri) as response:
        if response.status == HTTPStatus.OK:
            return await response.json()

        raise HTTPException(status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail='Failed to fetch JWKS')
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
async def get_verified_token(
    authorization: Annotated[HTTPAuthorizationCredentials, Depends(bearer_security)],
) -> TokenData | None:
    """Verify the bearer JWT (RS256) against the OIDC authority's JWKS and return its claims.

    Returns ``None`` without inspecting the token when OIDC is disabled. When
    ``oidc_settings.client_id`` is set, the token audience must match it.

    Raises:
        HTTPException: 403 when the token fails verification or lacks the claims
            required by ``TokenData``. 500 (from the fetch helpers) when the
            discovery document or JWKS cannot be downloaded.
    """
    global _JWKS

    from pydantic import ValidationError

    if not oidc_settings.enabled:
        return None

    token = authorization.credentials

    # NOTE(review): the JWKS is cached for the process lifetime and never refreshed,
    # so signing-key rotation at the authority requires a restart; confirm this is acceptable.
    try:
        if not _JWKS:
            openid_config = await fetch_openid_config()
            _JWKS = await fetch_jwks(openid_config['jwks_uri'])

        if oidc_settings.client_id:
            payload = jwt.decode(token, _JWKS, algorithms=['RS256'], audience=oidc_settings.client_id)
        else:
            # python-jose still validates a present 'aud' claim when no audience is
            # given, and rejects it. Turn the check off explicitly here.
            payload = jwt.decode(token, _JWKS, algorithms=['RS256'], options={'verify_aud': False})

        return TokenData(**payload)
    except JWTError as e:
        raise HTTPException(status_code=HTTPStatus.FORBIDDEN, detail=f'Token is invalid: {e}') from e
    except ValidationError as e:
        # The signature is valid, but required claims are missing or malformed.
        raise HTTPException(
            status_code=HTTPStatus.FORBIDDEN, detail='Token is invalid: missing or malformed claims'
        ) from e
|
python3_commons/cache.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import socket
|
|
3
|
+
from platform import platform
|
|
4
|
+
from typing import Any, Mapping, Sequence
|
|
5
|
+
|
|
6
|
+
import valkey
|
|
7
|
+
from pydantic import RedisDsn
|
|
8
|
+
from valkey.asyncio import ConnectionPool, Sentinel, StrictValkey, Valkey
|
|
9
|
+
from valkey.asyncio.retry import Retry
|
|
10
|
+
from valkey.backoff import FullJitterBackoff
|
|
11
|
+
from valkey.typing import ResponseT
|
|
12
|
+
|
|
13
|
+
from python3_commons.conf import valkey_settings
|
|
14
|
+
from python3_commons.helpers import SingletonMeta
|
|
15
|
+
from python3_commons.serializers.msgspec import (
|
|
16
|
+
deserialize_msgpack,
|
|
17
|
+
deserialize_msgpack_native,
|
|
18
|
+
serialize_msgpack,
|
|
19
|
+
serialize_msgpack_native,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class AsyncValkeyClient(metaclass=SingletonMeta):
    """Holder of the asyncio Valkey client, created through ``SingletonMeta``.

    Connects through Sentinel when *sentinel_dsn* is given. Otherwise it uses a
    connection pool built from *dsn*.
    """

    def __init__(self, dsn: RedisDsn, sentinel_dsn: RedisDsn | None):
        self._valkey_pool = None
        self._valkey = None

        if sentinel_dsn:
            self._initialize_sentinel(sentinel_dsn)
        else:
            self._initialize_standard_pool(dsn)

    @staticmethod
    def _get_keepalive_options() -> dict[int, int]:
        """Return TCP keep-alive socket options on Linux and macOS, or ``{}`` elsewhere.

        Options this Python build does not expose are skipped.
        """
        # Check sys.platform: the module-level ``platform`` name is the
        # platform.platform() function, which never equals 'linux' or 'darwin'.
        import sys

        if not (sys.platform.startswith('linux') or sys.platform == 'darwin'):
            return {}

        # macOS exposes the idle time as TCP_KEEPALIVE rather than TCP_KEEPIDLE.
        idle_option = getattr(socket, 'TCP_KEEPIDLE', getattr(socket, 'TCP_KEEPALIVE', None))
        candidates = (
            (idle_option, 10),
            (getattr(socket, 'TCP_KEEPINTVL', None), 5),
            (getattr(socket, 'TCP_KEEPCNT', None), 5),
        )

        return {option: value for option, value in candidates if option is not None}

    def _initialize_sentinel(self, dsn: RedisDsn):
        """Resolve the primary through the Sentinel at *dsn* and create a retrying client for it."""
        sentinel = Sentinel(
            [(dsn.host, dsn.port)],
            socket_connect_timeout=10,
            socket_timeout=60,
            password=dsn.password,
            sentinel_kwargs={'password': dsn.password},
        )

        ka_options = self._get_keepalive_options()

        # NOTE(review): the monitored service name 'myprimary' is hard-coded.
        self._valkey = sentinel.master_for(
            'myprimary',
            valkey_class=StrictValkey,
            socket_connect_timeout=10,
            socket_timeout=60,
            health_check_interval=30,
            retry_on_timeout=True,
            retry=Retry(FullJitterBackoff(cap=5, base=1), 5),
            socket_keepalive=True,
            socket_keepalive_options=ka_options,
        )

    def _initialize_standard_pool(self, dsn: RedisDsn):
        """Create a connection pool from *dsn* and a client bound to it."""
        self._valkey_pool = ConnectionPool.from_url(str(dsn))
        self._valkey = StrictValkey(connection_pool=self._valkey_pool)

    def get_client(self) -> Valkey:
        """Return the client created in ``__init__``."""
        return self._valkey
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def get_valkey_client() -> Valkey:
    """Return the Valkey client built from ``valkey_settings``."""
    holder = AsyncValkeyClient(valkey_settings.dsn, valkey_settings.sentinel_dsn)
    return holder.get_client()
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
async def scan(
    cursor: int = 0,
    match: bytes | str | memoryview | None = None,
    count: int | None = None,
    _type: str | None = None,
    **kwargs,
) -> ResponseT:
    """Run one SCAN step starting at *cursor* and return the client's raw result."""
    client = get_valkey_client()
    return await client.scan(cursor, match, count, _type, **kwargs)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
async def delete(*names: str | bytes | memoryview):
    """Delete every key in *names*; the DEL result is discarded."""
    client = get_valkey_client()
    await client.delete(*names)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
async def store_bytes(name: str, data: bytes, ttl: int | None = None, if_not_set: bool = False):
    """Store raw *data* under *name* with SET.

    Args:
        name: Key to write.
        data: Value stored as-is.
        ttl: Optional expiry in seconds (SET ``ex``).
        if_not_set: Only write when the key does not exist yet (SET ``nx``).

    Returns:
        The client's SET result. With *if_not_set*, this is presumably falsy when
        the key already existed.
    """
    r = get_valkey_client()

    return await r.set(name, data, ex=ttl, nx=if_not_set)
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
async def get_bytes(name: str) -> bytes | None:
    """Return the raw value stored under *name* (``None`` when the key is missing)."""
    return await get_valkey_client().get(name)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
async def store(name: str, obj: Any, ttl: int | None = None, if_not_set: bool = False):
    """Msgpack-encode *obj* and store it under *name*.

    *ttl* and *if_not_set* have the same meaning as in :func:`store_bytes`,
    whose result is returned.
    """
    return await store_bytes(name, serialize_msgpack_native(obj), ttl, if_not_set)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
async def get(name: str, default=None, data_type: Any = None) -> Any:
    """Return the msgpack-decoded value under *name*, or *default* when nothing (or an empty value) is stored."""
    raw = await get_bytes(name)
    return deserialize_msgpack_native(raw, data_type) if raw else default
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
async def store_string(name: str, data: str, ttl: int | None = None):
    """Store *data* UTF-8 encoded under *name*, optionally expiring after *ttl* seconds.

    Returns the :func:`store_bytes` result, as :func:`store` does.
    """
    return await store_bytes(name, data.encode('utf-8'), ttl)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
async def get_string(name: str) -> str | None:
    """Return the UTF-8 string stored under *name*, or ``None`` when the key is missing.

    A stored empty string comes back as ``''``.
    """
    data = await get_bytes(name)

    # Compare with None: b'' (an empty stored string) is falsy but is still a value.
    return None if data is None else data.decode('utf-8')
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
async def store_sequence(name: str, data: Sequence, ttl: int | None = None):
    """Append every item of *data*, msgpack-encoded, to the list at *name*.

    RPUSH appends, so items already stored under *name* are kept. The key is set
    to expire after *ttl* seconds when *ttl* is truthy. Empty *data* is a no-op.
    Connection errors are logged, not raised.
    """
    if not data:
        return

    try:
        r = get_valkey_client()
        await r.rpush(name, *map(serialize_msgpack_native, data))

        if ttl:
            await r.expire(name, ttl)
    except valkey.exceptions.ConnectionError as e:
        logger.error(f'Failed to store sequence in cache: {e}')
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
async def get_sequence(name: str, _type: type = list) -> Sequence:
    """Return all items of the list at *name*, msgpack-decoded and collected into *_type*."""
    client = get_valkey_client()
    items = await client.lrange(name, 0, -1)
    return _type(deserialize_msgpack_native(item) for item in items)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
async def store_dict(name: str, data: Mapping, ttl: int | None = None):
    """Write the fields of *data* (values msgpack-encoded) into the hash at *name*.

    HSET merges into an existing hash. The key is set to expire after *ttl*
    seconds when *ttl* is truthy. Empty *data* is a no-op. Connection errors are
    logged, not raised.
    """
    if not data:
        return

    try:
        r = get_valkey_client()
        # NOTE(review): values are encoded with serialize_msgpack_native, while
        # get_dict decodes with deserialize_msgpack (set_dict uses the matching
        # serialize_msgpack). Confirm the two formats are compatible.
        encoded = {k: serialize_msgpack_native(v) for k, v in data.items()}
        await r.hset(name, mapping=encoded)

        if ttl:
            await r.expire(name, ttl)
    except valkey.exceptions.ConnectionError as e:
        logger.error(f'Failed to store dict in cache: {e}')
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
async def get_dict(name: str, value_data_type=None) -> dict | None:
    """Return the hash at *name* with decoded keys and msgpack-decoded values, or ``None`` when empty or missing."""
    client = get_valkey_client()
    raw = await client.hgetall(name)

    if not raw:
        return None

    return {field.decode(): deserialize_msgpack(value, value_data_type) for field, value in raw.items()}
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
async def set_dict(name: str, mapping: dict, ttl: int | None = None):
    """Write *mapping* into the hash at *name*, with keys stringified and values msgpack-encoded.

    HSET merges into an existing hash. The key is set to expire after *ttl*
    seconds when *ttl* is truthy. An empty *mapping* is a no-op. Connection
    errors are logged, not raised.
    """
    if not mapping:
        return

    try:
        r = get_valkey_client()
        encoded = {str(k): serialize_msgpack(v) for k, v in mapping.items()}
        await r.hset(name, mapping=encoded)

        if ttl:
            await r.expire(name, ttl)
    except valkey.exceptions.ConnectionError as e:
        logger.error(f'Failed to set dict in cache: {e}')
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
async def get_dict_item(name: str, key: str, data_type=None, default=None):
    """Return the msgpack-decoded field *key* of the hash *name*.

    Returns *default* when the field is missing or empty, and ``None`` (after
    logging) on a connection error.
    """
    try:
        raw = await get_valkey_client().hget(name, key)
        return deserialize_msgpack_native(raw, data_type) if raw else default
    except valkey.exceptions.ConnectionError as e:
        logger.error(f'Failed to get dict item from cache: {e}')

    return None
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
async def set_dict_item(name: str, key: str, obj: Any):
    """Set field *key* of the hash *name* to msgpack-encoded *obj*; connection errors are logged."""
    try:
        client = get_valkey_client()
        encoded = serialize_msgpack_native(obj)
        await client.hset(name, key, encoded)
    except valkey.exceptions.ConnectionError as e:
        logger.error(f'Failed to set dict item in cache: {e}')
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
async def delete_dict_item(name: str, *keys):
    """Remove the fields *keys* from the hash *name*; connection errors are logged.

    Calling it with no keys is a no-op, because HDEL requires at least one field.
    """
    if not keys:
        return

    try:
        r = get_valkey_client()
        await r.hdel(name, *keys)
    except valkey.exceptions.ConnectionError as e:
        logger.error(f'Failed to delete dict item from cache: {e}')
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
async def store_set(name: str, value: set, ttl: int | None = None):
    """Add every member of *value*, msgpack-encoded, to the set at *name*.

    The key is set to expire after *ttl* seconds when *ttl* is truthy. An empty
    *value* is a no-op (SADD requires at least one member), matching
    ``store_sequence`` and ``store_dict``. Connection errors are logged, not raised.
    """
    if not value:
        return

    try:
        r = get_valkey_client()
        await r.sadd(name, *map(serialize_msgpack_native, value))

        if ttl:
            await r.expire(name, ttl)
    except valkey.exceptions.ConnectionError as e:
        logger.error(f'Failed to store set in cache: {e}')
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
async def has_set_item(name: str, value: str) -> bool:
    """Return whether msgpack-encoded *value* is a member of the set *name*.

    Returns ``False`` (after logging) on a connection error.
    """
    try:
        membership = await get_valkey_client().sismember(name, serialize_msgpack_native(value))
        return membership == 1
    except valkey.exceptions.ConnectionError as e:
        logger.error(f'Failed to check if set has item in cache: {e}')

    return False
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
async def add_set_item(name: str, *values: str):
    """Add the msgpack-encoded *values* to the set *name*; connection errors are logged.

    Calling it with no values is a no-op, because SADD requires at least one member.
    """
    if not values:
        return

    try:
        r = get_valkey_client()
        await r.sadd(name, *map(serialize_msgpack_native, values))
    except valkey.exceptions.ConnectionError as e:
        logger.error(f'Failed to add set item into cache: {e}')
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
async def delete_set_item(name: str, value: str):
    """Remove msgpack-encoded *value* from the set *name*."""
    await get_valkey_client().srem(name, serialize_msgpack_native(value))
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
async def get_set_members(name: str) -> set[str] | None:
    """Return the msgpack-decoded members of the set *name*, or ``None`` (after logging) on a connection error."""
    try:
        client = get_valkey_client()
        members = await client.smembers(name)
        return {deserialize_msgpack_native(member) for member in members}
    except valkey.exceptions.ConnectionError as e:
        logger.error(f'Failed to get set members from cache: {e}')

    return None
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
async def exists(name: str) -> bool:
    """Return whether the key *name* exists (EXISTS returns exactly 1)."""
    found = await get_valkey_client().exists(name)
    return found == 1
|