sotkalib-0.0.2-py3-none-any.whl → sotkalib-0.0.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sotkalib/__init__.py +2 -4
- sotkalib/config/field.py +0 -16
- sotkalib/config/struct.py +121 -121
- sotkalib/enum/__init__.py +3 -0
- sotkalib/enum/mixins.py +59 -0
- sotkalib/exceptions/__init__.py +3 -0
- sotkalib/exceptions/api/__init__.py +1 -0
- sotkalib/exceptions/api/exc.py +53 -0
- sotkalib/exceptions/handlers/__init__.py +4 -0
- sotkalib/exceptions/handlers/args_incl_error.py +15 -0
- sotkalib/exceptions/handlers/core.py +33 -0
- sotkalib/http/__init__.py +12 -12
- sotkalib/http/client_session.py +217 -206
- sotkalib/log/factory.py +1 -18
- sotkalib/redis/__init__.py +8 -0
- sotkalib/redis/client.py +38 -0
- sotkalib/redis/lock.py +82 -0
- sotkalib/sqla/__init__.py +3 -0
- sotkalib/sqla/db.py +101 -0
- {sotkalib-0.0.2.dist-info → sotkalib-0.0.4.dist-info}/METADATA +3 -1
- sotkalib-0.0.4.dist-info/RECORD +25 -0
- {sotkalib-0.0.2.dist-info → sotkalib-0.0.4.dist-info}/WHEEL +1 -1
- sotkalib-0.0.2.dist-info/RECORD +0 -12
sotkalib/http/client_session.py
CHANGED
(removed 0.0.2 lines that the source view did not render are omitted from the hunks below)

@@ -1,7 +1,7 @@
 import asyncio
-import importlib
 import ssl
-from collections.abc import Callable
+from collections.abc import Callable, Mapping, Sequence
+from functools import reduce
 from http import HTTPStatus
 from typing import Any, Literal, Protocol, Self
 
@@ -11,237 +11,248 @@ from pydantic import BaseModel, ConfigDict, Field
 
 from sotkalib.log import get_logger
 
-try:
-    certifi = importlib.import_module("certifi")
-except ImportError:
-    certifi = None
-
-
 MAXIMUM_BACKOFF: float = 120
 
+try:
+    import certifi
+except ImportError:
+    certifi = None
 
+class RanOutOfAttemptsError(Exception):
+    pass
 
+class CriticalStatusError(Exception):
+    pass
 
 class StatusRetryError(Exception):
+    status: int
+    context: str
 
+    def __init__(self, status: int, context: str) -> None:
+        super().__init__(f"{status}: {context}")
+        self.status = status
+        self.context = context
 
+type ExcArgFunc = Callable[..., tuple[Sequence[Any], Mapping[str, Any] | None]]
+type StatArgFunc = Callable[..., Any]
 
+async def default_stat_arg_func(resp: aiohttp.ClientResponse) -> tuple[Sequence[Any], None]:
+    return (f"[{resp.status}]; {await resp.text()=}",), None
 
 class StatusSettings(BaseModel):
+    model_config = ConfigDict(arbitrary_types_allowed=True)
 
+    to_raise: set[HTTPStatus] = Field(default={HTTPStatus.FORBIDDEN})
+    to_retry: set[HTTPStatus] = Field(default={HTTPStatus.TOO_MANY_REQUESTS, HTTPStatus.FORBIDDEN})
+    exc_to_raise: type[Exception] = Field(default=CriticalStatusError)
+    not_found_as_none: bool = Field(default=True)
+    args_for_exc_func: StatArgFunc = Field(default=default_stat_arg_func)
+    unspecified: Literal["retry", "raise"] = Field(default="retry")
 
+def default_exc_arg_func(exc: Exception, attempt: int, url: str, method: str, **kw) -> tuple[Sequence[Any], None]:
+    return (f"exception {type(exc)}: ({exc=}) {attempt=}; {url=} {method=} {kw=}",), None
 
 class ExceptionSettings(BaseModel):
-    to_raise: tuple[type[Exception]] = Field(
-        default=(
-            client_exceptions.ConnectionTimeoutError,
-            client_exceptions.ClientProxyConnectionError,
-            client_exceptions.ContentTypeError,
-        ),
-    )
+    model_config = ConfigDict(arbitrary_types_allowed=True)
 
+    to_raise: tuple[type[Exception]] = Field(
+        default=(
+            client_exceptions.ConnectionTimeoutError,
+            client_exceptions.ClientProxyConnectionError,
+            client_exceptions.ContentTypeError,
+        ),
+    )
 
+    to_retry: tuple[type[Exception]] = Field(
+        default=(
+            TimeoutError,
+            client_exceptions.ServerDisconnectedError,
+            client_exceptions.ClientConnectionResetError,
+            client_exceptions.ClientOSError,
+            client_exceptions.ClientHttpProxyError,
+        ),
+    )
 
+    exc_to_raise: type[Exception] | None = Field(default=None)
+    args_for_exc_func: ExcArgFunc = Field(default=default_exc_arg_func)
+    unspecified: Literal["retry", "raise"] = Field(default="retry")
 
 
+class ClientSettings(BaseModel):
+    timeout: float = Field(default=5.0, gt=0)
+    base: float = Field(default=1.0, gt=0)
+    backoff: float = Field(default=2.0, gt=0)
+    maximum_retries: int = Field(default=3, ge=1)
 
+    useragent_factory: Callable[[], str] | None = Field(default=None)
 
+    status_settings: StatusSettings = Field(default_factory=StatusSettings)
+    exception_settings: ExceptionSettings = Field(default_factory=ExceptionSettings)
 
+    session_kwargs: dict[str, Any] = Field(default_factory=dict)
+    use_cookies_from_response: bool = Field(default=False)
 
 
-class Handler[T](Protocol):
+class Handler[**P, T](Protocol):
+    async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...
 
 
-type Middleware[T, R] = Callable[[Handler[T]], Handler[R]]
+type Middleware[**P, T, R] = Callable[[Handler[P, T]], Handler[P, R]]
 
 
 def _make_ssl_context(disable_tls13: bool = False) -> ssl.SSLContext:
+    ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+    ctx.load_default_certs()
+
+    if certifi:
+        ctx.load_verify_locations(certifi.where())
+
+    ctx.minimum_version = ssl.TLSVersion.TLSv1_2
+    ctx.maximum_version = ssl.TLSVersion.TLSv1_2 if disable_tls13 else ssl.TLSVersion.TLSv1_3
+
+    ctx.set_ciphers(
+        "TLS_AES_256_GCM_SHA384:TLS_AES_128_GCM_SHA256:"
+        "TLS_CHACHA20_POLY1305_SHA256:"
+        "ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:"
+        "ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256"
+    )
+
+    ctx.check_hostname = True
+    ctx.verify_mode = ssl.CERT_REQUIRED
+
+    return ctx
+
+
+class HTTPSession[R = aiohttp.ClientResponse | None]:
+    config: ClientSettings
+    _session: aiohttp.ClientSession
+    _middlewares: list[Middleware]
+
+    def __init__(
+        self,
+        config: ClientSettings | None = None,
+        _middlewares: list[Middleware] | None = None,
+    ) -> None:
+        self.config = config if config is not None else ClientSettings()
+        self._session = None
+        self._middlewares = _middlewares or []
+
+    def use[**P, NewR](self, mw: Middleware[P, R, NewR]) -> HTTPSession[NewR]:
+        new_session: HTTPSession[NewR] = HTTPSession(
+            config=self.config,
+            _middlewares=[*self._middlewares, mw],
+        )
+        return new_session
+
+    async def __aenter__(self) -> Self:
+        ctx = _make_ssl_context(disable_tls13=False)
+
+        if self.config.session_kwargs.get("connector") is None:
+            self.config.session_kwargs["connector"] = aiohttp.TCPConnector(ssl=ctx)
+        if self.config.session_kwargs.get("trust_env") is None:
+            self.config.session_kwargs["trust_env"] = False
+
+        self._session = aiohttp.ClientSession(
+            timeout=aiohttp.ClientTimeout(total=self.config.timeout),
+            **self.config.session_kwargs,
+        )
+
+        get_logger("http.client_session").debug(
+            f"RetryableClientSession initialized with timeout: {self.config.timeout}"
+        )
+
+        return self
+
+    async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None:
+        if self._session:
+            await self._session.close()
+
+    async def _handle_statuses(self, response: aiohttp.ClientResponse) -> aiohttp.ClientResponse | None:
+        sc = response.status
+        exc, argfunc = self.config.status_settings.exc_to_raise, self.config.status_settings.args_for_exc_func
+        if self.config.use_cookies_from_response:
+            self._session.cookie_jar.update_cookies(response.cookies)
+        if sc in self.config.status_settings.to_retry:
+            raise StatusRetryError(status=sc, context=(await response.text()))
+        elif sc in self.config.status_settings.to_raise:
+            a, kw = await argfunc(response)
+            if kw is None:
+                raise exc(*a)
+            raise exc(*a, **kw)
+        elif self.config.status_settings.not_found_as_none and sc == HTTPStatus.NOT_FOUND:
+            return None
+
+        return response
+
+    def _get_make_request_func(self) -> Callable[..., Any]:
+        async def _make_request(*args: Any, **kwargs: Any) -> aiohttp.ClientResponse | None:
+            return await self._handle_statuses(await self._session.request(*args, **kwargs))
+
+        return reduce(lambda t, s: s(t), reversed(self._middlewares), _make_request)
+
+    async def _handle_request(
+        self,
+        method: str,
+        url: str,
+        make_request_func: Callable[..., Any],
+        **kw: Any,
+    ) -> R:
+        if self.config.useragent_factory is not None:
+            user_agent_header = {"User-Agent": self.config.useragent_factory()}
+            kw["headers"] = kw.get("headers", {}) | user_agent_header
+
+        return await make_request_func(method, url, **kw)
+
+    async def _handle_retry(self, e: Exception, attempt: int, url: str, method: str, **kws: Any) -> None:
+        if attempt == self.config.maximum_retries:
+            raise RanOutOfAttemptsError(f"failed after {self.config.maximum_retries} retries: {type(e)} {e}") from e
+
+        await asyncio.sleep(self.config.base * min(MAXIMUM_BACKOFF, self.config.backoff**attempt))
+
+    async def _handle_to_raise(self, e: Exception, attempt: int, url: str, method: str, **kw: Any) -> None:
+        if self.config.exception_settings.exc_to_raise is None:
+            raise e
+
+        exc, argfunc = self.config.exception_settings.exc_to_raise, self.config.exception_settings.args_for_exc_func
+
+        a, exckw = argfunc(e, attempt, url, method, **kw)
+        if exckw is None:
+            raise exc(*a) from e
+
+        raise exc(*a, **exckw) from e
+
+    async def _handle_exception(self, e: Exception, attempt: int, url: str, method: str, **kw: Any) -> None:
+        if self.config.exception_settings.unspecified == "raise":
+            raise e
+
+        await self._handle_retry(e, attempt, url, method, **kw)
+
+    async def _request_with_retry(self, method: str, url: str, **kw: Any) -> R:
+        _make_request = self._get_make_request_func()
+        for attempt in range(self.config.maximum_retries + 1):
+            try:
+                return await self._handle_request(method, url, _make_request, **kw)
+            except self.config.exception_settings.to_retry + (StatusRetryError,) as e:
+                await self._handle_retry(e, attempt, url, method, **kw)
+            except self.config.exception_settings.to_raise as e:
+                await self._handle_to_raise(e, attempt, url, method, **kw)
+            except Exception as e:
+                await self._handle_exception(e, attempt, url, method, **kw)
+
+        return await _make_request()
+
+    async def get(self, url: str, **kwargs: Any) -> R:
+        return await self._request_with_retry("GET", url, **kwargs)
+
+    async def post(self, url: str, **kwargs: Any) -> R:
+        return await self._request_with_retry("POST", url, **kwargs)
+
+    async def put(self, url: str, **kwargs: Any) -> R:
+        return await self._request_with_retry("PUT", url, **kwargs)
+
+    async def delete(self, url: str, **kwargs: Any) -> R:
+        return await self._request_with_retry("DELETE", url, **kwargs)
+
+    async def patch(self, url: str, **kwargs: Any) -> R:
+        return await self._request_with_retry("PATCH", url, **kwargs)
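For orientation, a minimal usage sketch of the HTTPSession API added above. The class, settings, and middleware shape are taken from the diff; the import path, the logging middleware, and the URL are assumptions for illustration.

import asyncio

# Assumed import path, based on the file location shown in this diff.
from sotkalib.http.client_session import ClientSettings, HTTPSession


def logging_middleware(handler):
    # A middleware takes the request handler and returns a new handler.
    async def wrapped(method, url, **kw):
        resp = await handler(method, url, **kw)
        print(f"{method} {url} -> {resp.status if resp is not None else 'None'}")
        return resp

    return wrapped


async def main() -> None:
    settings = ClientSettings(timeout=10.0, maximum_retries=5)

    # use() returns a new HTTPSession with the middleware appended;
    # middlewares are composed around the raw aiohttp request via reduce().
    async with HTTPSession(config=settings).use(logging_middleware) as http:
        resp = await http.get("https://example.com/api/items")  # hypothetical URL
        if resp is not None:  # NOT_FOUND is mapped to None when not_found_as_none=True
            print(await resp.text())


asyncio.run(main())

Between retryable failures the session sleeps base * min(MAXIMUM_BACKOFF, backoff ** attempt) and raises RanOutOfAttemptsError once maximum_retries is exhausted.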
sotkalib/log/factory.py
CHANGED

@@ -9,21 +9,4 @@ if TYPE_CHECKING:
 
 @lru_cache
 def get_logger(logger_name: str | None = None) -> Logger:
-    """
-
-    Return a cached loguru Logger optionally bound with a humanized name.
-
-    If a name is provided, the returned logger is bound with extra["logger_name"]
-    in a " src -> sub -> leaf " format so it can be referenced in loguru sinks.
-
-    **Parameters:**
-
-    - `logger_name`: Dotted logger name (e.g., "src.database.service"). If None, return the global logger.
-
-    **Returns:**
-
-    A cached loguru Logger with the extra context bound when name is provided.
-
-    """
-
-    return logger if logger_name is None else logger.bind(logger_name=f" {logger_name.replace('.', ' -> ')} ")
+    return logger if logger_name is None else logger.bind(name=logger_name.replace(".", " -> "))
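The practical effect of this change is that loguru sinks must now reference extra["name"] instead of extra["logger_name"], and the bound value no longer carries padding spaces. A small sketch, assuming a stderr sink; the format string is illustrative, only the bound key comes from the diff.

import sys

from loguru import logger

from sotkalib.log import get_logger  # package path as imported elsewhere in this diff

# Sink that prints the humanized name bound by get_logger ("src -> database -> service").
logger.add(
    sys.stderr,
    format="{time:HH:mm:ss} | {level} | {extra[name]} | {message}",
    filter=lambda record: "name" in record["extra"],
)

get_logger("src.database.service").info("connected")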
sotkalib/redis/client.py
ADDED

@@ -0,0 +1,38 @@
+import asyncio
+from contextlib import AbstractAsyncContextManager
+
+from pydantic import BaseModel, Field
+from redis.asyncio import ConnectionPool, Redis
+
+
+class RedisPoolSettings(BaseModel):
+    uri: str = Field(default="redis://localhost:6379")
+    db_num: int = Field(default=4)
+    max_connections: int = Field(default=50)
+    socket_timeout: float = Field(default=5)
+    socket_connect_timeout: float = Field(default=5)
+    retry_on_timeout: bool = Field(default=True)
+    health_check_interval: float = Field(default=30)
+    decode_responses: bool = Field(default=True)
+
+
+class RedisPool(AbstractAsyncContextManager):
+    def __init__(self, settings: RedisPoolSettings | None = None):
+        if not settings:
+            settings = RedisPoolSettings()
+
+        self._pool = ConnectionPool.from_url(
+            settings.uri + "/" + str(settings.db_num),
+            **settings.model_dump(exclude={"uri", "db_num"}),
+        )
+
+        self._usage_counter = 0
+        self._usage_lock = asyncio.Lock()
+
+    async def __aenter__(self: RedisPool) -> Redis:
+        try:
+            return Redis(connection_pool=self._pool)
+        except Exception:
+            raise
+
+    async def __aexit__(self, exc_type, exc_value, traceback): ...
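A minimal sketch of how the new RedisPool wrapper might be used, assuming a Redis server at the default URI; the key and value are hypothetical.

import asyncio

# Assumed import path, based on the file location shown in this diff.
from sotkalib.redis.client import RedisPool, RedisPoolSettings


async def main() -> None:
    pool = RedisPool(RedisPoolSettings(uri="redis://localhost:6379", db_num=4))

    # __aenter__ returns a Redis client bound to the shared connection pool;
    # __aexit__ is a no-op in this release, so the pool itself is left open.
    async with pool as redis:
        await redis.set("greeting", "hello")  # hypothetical key
        print(await redis.get("greeting"))    # "hello" (decode_responses=True)


asyncio.run(main())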
sotkalib/redis/lock.py
ADDED

@@ -0,0 +1,82 @@
+import asyncio
+from collections.abc import AsyncGenerator
+from contextlib import AbstractAsyncContextManager, asynccontextmanager
+from time import time
+from typing import Any
+
+from redis.asyncio import Redis
+
+
+class ContextLockError(Exception):
+    def __init__(self, *args, can_retry: bool = True):
+        super().__init__(*args)
+        self.can_retry = can_retry
+
+
+async def __try_acquire(rc: Redis, key_to_lock: str, acquire_timeout: int) -> bool:
+    """Atomically acquire a lock using SET NX (set-if-not-exists)."""
+    return bool(await rc.set(key_to_lock, "acquired", nx=True, ex=acquire_timeout))
+
+
+async def __wait_till_lock_free(
+    client: Redis,
+    key_to_lock: str,
+    lock_timeout: float = 10.0,
+    base_delay: float = 0.1,
+    max_delay: float = 5.0,
+) -> None:
+    start = time()
+    attempt = 0
+    while await client.get(key_to_lock) is not None:
+        if (time() - start) > lock_timeout:
+            raise ContextLockError(
+                f"{key_to_lock} lock already acquired, timeout after {lock_timeout}s",
+                can_retry=False,
+            )
+        delay = min(base_delay * (2**attempt), max_delay)
+        await asyncio.sleep(delay)
+        attempt += 1
+
+
+@asynccontextmanager
+async def redis_context_lock(
+    client: AbstractAsyncContextManager[Redis],
+    key_to_lock: str,
+    can_retry_if_lock_catched: bool = True,
+    wait_for_lock: bool = False,
+    wait_timeout: float = 60.0,
+    acquire_timeout: int = 5,
+    args_to_lock_exception: list[Any] | None = None,
+) -> AsyncGenerator[None]:
+    """
+    Acquire a Redis lock atomically using SET NX.
+
+    :param client: async context mng for redis
+    :param key_to_lock: Redis key for the lock
+    :param can_retry_if_lock_catched: Whether task should retry if lock is taken (only used if wait_for_lock=False)
+    :param wait_for_lock: If True, wait for lock to be free instead of immediately failing
+    :param wait_timeout: Maximum time to wait for lock in seconds (only used if wait_for_lock=True)
+    :param acquire_timeout: Timeout for acquiring lock
+    :param args_to_lock_exception: Args to pass to ContextLockError
+
+    """
+    if args_to_lock_exception is None:
+        args_to_lock_exception = []
+
+    if wait_for_lock:
+        async with client as rc:
+            await __wait_till_lock_free(key_to_lock=key_to_lock, client=rc, lock_timeout=wait_timeout)
+
+    try:
+        async with client as rc:
+            acquired = await __try_acquire(rc, key_to_lock, acquire_timeout)
+            if not acquired:
+                raise ContextLockError(
+                    f"{key_to_lock} lock already acquired",
+                    *args_to_lock_exception,
+                    can_retry=can_retry_if_lock_catched,
+                )
+        yield
+    finally:
+        async with client as rc:
+            await rc.delete(key_to_lock)
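A sketch of combining redis_context_lock with the RedisPool added in this release; the lock key and the guarded work are hypothetical.

import asyncio

# Assumed import paths, based on the file locations shown in this diff.
from sotkalib.redis.client import RedisPool
from sotkalib.redis.lock import ContextLockError, redis_context_lock


async def refresh_cache() -> None:
    pool = RedisPool()  # defaults to redis://localhost:6379, db 4

    try:
        # wait_for_lock=True polls with exponential backoff until the key is free,
        # then acquires it atomically via SET NX with acquire_timeout as the TTL.
        async with redis_context_lock(pool, "locks:cache-refresh", wait_for_lock=True, wait_timeout=30.0):
            ...  # work that must not run concurrently
    except ContextLockError as e:
        if e.can_retry:
            print("lock is busy, safe to retry later")
        else:
            raise


asyncio.run(refresh_cache())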