sqlacache 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlacache/__init__.py +14 -0
- sqlacache/config.py +175 -0
- sqlacache/contrib/__init__.py +1 -0
- sqlacache/contrib/fastapi.py +1 -0
- sqlacache/contrib/prometheus.py +1 -0
- sqlacache/exceptions.py +13 -0
- sqlacache/interceptor.py +128 -0
- sqlacache/invalidation.py +45 -0
- sqlacache/manager.py +348 -0
- sqlacache/pubsub/__init__.py +3 -0
- sqlacache/pubsub/redis.py +89 -0
- sqlacache/py.typed +0 -0
- sqlacache/serializers/__init__.py +5 -0
- sqlacache/serializers/json.py +24 -0
- sqlacache/transport/__init__.py +45 -0
- sqlacache/transport/cashews.py +136 -0
- sqlacache/utils/__init__.py +1 -0
- sqlacache/utils/key_generation.py +40 -0
- sqlacache/utils/query_analysis.py +108 -0
- sqlacache/utils/sync_wrapper.py +16 -0
- sqlacache-0.1.0.dist-info/METADATA +298 -0
- sqlacache-0.1.0.dist-info/RECORD +24 -0
- sqlacache-0.1.0.dist-info/WHEEL +4 -0
- sqlacache-0.1.0.dist-info/licenses/LICENSE +21 -0
sqlacache/manager.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
1
|
+
"""Cache manager primitives."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
from contextlib import suppress
|
|
7
|
+
from typing import TYPE_CHECKING, Any, ClassVar
|
|
8
|
+
|
|
9
|
+
from sqlalchemy import event, select
|
|
10
|
+
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession
|
|
11
|
+
from sqlalchemy.orm import Session
|
|
12
|
+
|
|
13
|
+
from sqlacache.interceptor import (
|
|
14
|
+
build_bulk_delete_handler,
|
|
15
|
+
build_bulk_update_handler,
|
|
16
|
+
build_do_orm_execute_handler,
|
|
17
|
+
build_invalidation_handler,
|
|
18
|
+
handle_bulk_mutation,
|
|
19
|
+
merge_cached_result,
|
|
20
|
+
resolve_cached_result,
|
|
21
|
+
)
|
|
22
|
+
from sqlacache.invalidation import _bump_table_version, _get_table_version, generate_tags, invalidate_tags
|
|
23
|
+
from sqlacache.transport.cashews import CashewsTransport
|
|
24
|
+
from sqlacache.utils.key_generation import generate_cache_key
|
|
25
|
+
from sqlacache.utils.query_analysis import extract_model_from_statement, extract_pk_from_instance
|
|
26
|
+
|
|
27
|
+
if TYPE_CHECKING:
|
|
28
|
+
from sqlacache.pubsub.redis import RedisPubSub
|
|
29
|
+
from sqlacache.transport import CacheTransport
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class CacheManager:
    """Manage normalized cache configuration and runtime bindings.

    A manager owns one cache transport, the SQLAlchemy event listeners it
    registers in :meth:`bind`, an optional Redis pub/sub connection used to
    broadcast invalidations, and the background tasks it schedules.
    """

    # ``AsyncSession.get`` is patched once for all managers (see
    # ``_patch_async_get``) and restored when the engine registry empties.
    _async_get_patched = False
    _original_async_get: Any = None
    # id(sync engine) -> manager bound to that engine; shared by all instances.
    _engine_registry: ClassVar[dict[int, CacheManager]] = {}

    def __init__(self, config: dict[str, Any]) -> None:
        """Store *config*; nothing connects until :meth:`bind` or a cache call.

        Args:
            config: Normalized configuration mapping. This class reads the
                ``models``, ``wildcard``, ``backend``, ``prefix`` and
                ``default_timeout`` keys.
        """
        self._config = config
        self._transport: CacheTransport | None = None
        self._model_config: dict[str, dict[str, Any] | None] = dict(config.get("models", {}))
        self._wildcard_config: dict[str, Any] | None = config.get("wildcard")
        self._bound_engine: Any | None = None
        self._bound_sync_engine: Any | None = None
        # NOTE(review): never assigned a task in this module; only cancelled in disconnect().
        self._pubsub_task: asyncio.Task[Any] | None = None
        self._pubsub: RedisPubSub | None = None
        # (target, event name, listener) triples so disconnect() can remove them.
        self._listeners: list[tuple[Any, str, Any]] = []
        # References to scheduled background tasks until they finish; disconnect() cancels them.
        self._pending_tasks: set[asyncio.Task[Any]] = set()

    @property
    def config(self) -> dict[str, Any]:
        """The configuration mapping passed to the constructor."""
        return self._config

    async def bind(self, engine: Any) -> None:
        """Attach this manager to *engine*.

        For an ``AsyncEngine`` its ``sync_engine`` is registered; any other
        object is registered as-is. If a transport is already connected,
        :meth:`disconnect` runs first. Binding then connects the transport,
        patches ``AsyncSession.get``, installs ORM event listeners and, for
        ``redis*`` backends, starts pub/sub.
        """
        if self._transport is not None:
            await self.disconnect()
        self._bound_engine = engine
        self._bound_sync_engine = engine.sync_engine if isinstance(engine, AsyncEngine) else engine
        if self._bound_sync_engine is not None:
            self._engine_registry[id(self._bound_sync_engine)] = self
        await self._ensure_transport()
        self._patch_async_get()
        self._register_listeners()
        await self._maybe_setup_pubsub()

    def get_model_config(self, model: type[Any]) -> dict[str, Any] | None:
        """Return the per-model config for *model*, else the wildcard config.

        Models are looked up by their dotted ``module.ClassName`` path.
        """
        model_path = f"{model.__module__}.{model.__name__}"
        if model_path in self._model_config:
            return self._model_config[model_path]
        return self._wildcard_config

    def is_enabled(self, model: type[Any], op: str) -> bool:
        """Return whether operation *op* is listed in *model*'s ``ops`` config."""
        model_config = self.get_model_config(model)
        if not model_config:
            return False
        return op in model_config["ops"]

    async def execute(self, session: Any, statement: Any, timeout: int | None = None) -> Any:
        """Execute *statement* on *session* through the cache.

        Statements from which no model can be extracted run uncached. Otherwise
        the cache key combines the statement with every model's table version.
        A hit is merged into *session*. A miss runs the query, stores the frozen
        result tagged with the fetched primary keys, and merges that.

        Args:
            session: Session whose ``execute`` is awaited on a miss.
            statement: The statement to run.
            timeout: Expiry override. Falsy values fall back to the primary
                model's ``timeout`` or the config's ``default_timeout``.
        """
        await self._ensure_transport()
        models = self._extract_models(statement)
        # Mark the statement so the do_orm_execute interceptor does not re-handle it.
        execution_statement = statement.execution_options(sqlacache_skip_interceptor=True)
        if not models:
            return await session.execute(execution_statement)

        key = await self._build_cache_key(statement, models)
        transport = self._transport
        assert transport is not None
        cached = await transport.get(key)
        if cached is not None:
            return merge_cached_result(session, statement, cached)

        result = await session.execute(execution_statement)
        frozen = result.freeze()
        primary_model = models[0]
        model_config = self.get_model_config(primary_model)
        expire = timeout or (model_config["timeout"] if model_config else self._config["default_timeout"])
        from sqlacache.utils.query_analysis import extract_pks_from_fetch_result

        pks_by_model = extract_pks_from_fetch_result(list(frozen.data), models)
        tags: list[str] = []
        for model, pks in pks_by_model.items():
            tags.extend(generate_tags(model, pks))
        await transport.set(key, frozen, expire=expire, tags=tags)
        return merge_cached_result(session, statement, frozen)

    async def invalidate(self, model: type[Any] | None = None, pks: list[Any] | None = None) -> None:
        """Invalidate cached entries and broadcast the event over pub/sub.

        With no *model*, the whole cache is cleared. With *pks*, only the
        entries tagged with those keys are dropped. Otherwise the model's
        table version is bumped, which changes every cache key built for it.
        """
        await self._ensure_transport()
        if model is None:
            await self.invalidate_all()
            return
        transport = self._transport
        assert transport is not None
        if pks:
            await invalidate_tags(transport, *generate_tags(model, pks))
            await self._publish_invalidation(model, pks, action="manual")
            return
        await _bump_table_version(transport, model)
        await self._publish_invalidation(model, [], action="table")

    async def invalidate_all(self) -> None:
        """Clear every entry in the transport (nothing is published over pub/sub)."""
        await self._ensure_transport()
        assert self._transport is not None
        await self._transport.clear()

    async def disconnect(self) -> None:
        """Tear down listeners, pub/sub, pending tasks, the transport and engine binding.

        The class-level ``AsyncSession.get`` patch is restored once no engine
        remains registered.
        """
        for target, identifier, listener in self._listeners:
            with suppress(Exception):
                event.remove(target, identifier, listener)
        self._listeners.clear()
        if self._pubsub_task is not None:
            self._pubsub_task.cancel()
            self._pubsub_task = None
        if self._pubsub is not None:
            await self._pubsub.disconnect()
            self._pubsub = None
        for task in list(self._pending_tasks):
            task.cancel()
        self._pending_tasks.clear()
        if self._transport is not None:
            await self._transport.disconnect()
            self._transport = None
        if self._bound_sync_engine is not None:
            engine_id = id(self._bound_sync_engine)
            # Another manager may have been bound to the same engine since we
            # were; only drop the registry entry if it still points at us.
            if self._engine_registry.get(engine_id) is self:
                del self._engine_registry[engine_id]
            if not self._engine_registry:
                self._restore_async_get()
        self._bound_engine = None
        self._bound_sync_engine = None

    def bind_sync(self, engine: Any) -> None:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError("Sync session support is deferred to v0.2.0")

    def execute_sync(self, session: Any, statement: Any, timeout: int | None = None) -> Any:
        """Blocking wrapper around :meth:`execute` (requires no running event loop)."""
        return self._run_sync(self.execute(session, statement, timeout=timeout))

    def invalidate_sync(self, model: type[Any] | None = None, pks: list[Any] | None = None) -> None:
        """Blocking wrapper around :meth:`invalidate` (requires no running event loop)."""
        self._run_sync(self.invalidate(model=model, pks=pks))

    def _matches_session(self, session: Session) -> bool:
        """Return True when *session* resolves to this manager's bound sync engine."""
        if self._bound_sync_engine is None:
            return False
        try:
            bind = session.get_bind()
        except Exception:
            return False
        return bind is self._bound_sync_engine

    def _extract_models(self, statement: Any) -> list[type[Any]]:
        """Normalize ``extract_model_from_statement`` output to a (possibly empty) list."""
        extracted = extract_model_from_statement(statement)
        if extracted is None:
            return []
        if isinstance(extracted, list):
            return extracted
        return [extracted]

    async def _build_cache_key(self, statement: Any, models: list[type[Any]]) -> str:
        """Build the statement key, suffixed with ``:v<ver>.<ver>...`` per model."""
        prefix = self._config["prefix"]
        base_key = generate_cache_key(statement, prefix=prefix)
        if not models:
            return base_key
        transport = self._transport
        assert transport is not None
        versions = [str(await _get_table_version(transport, model)) for model in models]
        return f"{base_key}:v{'.'.join(versions)}"

    def _register_listeners(self) -> None:
        """Install the select, per-model mutation and bulk-mutation listeners.

        Model paths that fail to resolve are skipped silently; ``"*"`` is skipped.
        """
        from sqlacache.config import _resolve_model

        select_listener = build_do_orm_execute_handler(self)
        self._listeners.append((Session, "do_orm_execute", select_listener))
        event.listen(Session, "do_orm_execute", select_listener, retval=True)

        for model_path in self._model_config:
            if model_path == "*":
                continue
            try:
                model = _resolve_model(model_path)
            except Exception:
                continue
            for identifier, action in (
                ("after_insert", "insert"),
                ("after_update", "update"),
                ("after_delete", "delete"),
            ):
                listener = build_invalidation_handler(self, action)
                self._listeners.append((model, identifier, listener))
                event.listen(model, identifier, listener, propagate=True)

        bulk_update_listener = build_bulk_update_handler(self)
        self._listeners.append((Session, "after_bulk_update", bulk_update_listener))
        event.listen(Session, "after_bulk_update", bulk_update_listener)

        bulk_delete_listener = build_bulk_delete_handler(self)
        self._listeners.append((Session, "after_bulk_delete", bulk_delete_listener))
        event.listen(Session, "after_bulk_delete", bulk_delete_listener)

    async def _maybe_setup_pubsub(self) -> None:
        """Start Redis pub/sub invalidation when the backend URL starts with ``redis``.

        The backend may be configured as a bare URL string or as a mapping
        with a ``url`` key. Any other or missing backend disables pub/sub.
        """
        from collections.abc import Mapping

        backend = self._config.get("backend")
        if isinstance(backend, str):
            redis_url: Any = backend
        elif isinstance(backend, Mapping):
            redis_url = backend.get("url")
        else:
            redis_url = None
        if not isinstance(redis_url, str) or not redis_url.startswith("redis"):
            return

        # Deferred so the Redis client library is only imported for Redis backends.
        from sqlacache.pubsub.redis import RedisPubSub

        self._pubsub = RedisPubSub(redis_url)
        await self._pubsub.connect()

        async def on_invalidate(event_payload: dict[str, Any]) -> None:
            # Apply an invalidation received from the channel (possibly our own publish).
            table = event_payload.get("table")
            pks = event_payload.get("pks", [])
            if not table:
                return
            transport = self._transport
            assert transport is not None
            if pks:
                await invalidate_tags(transport, *(f"{table}:{pk}" for pk in pks))
            else:
                # NOTE(review): passes the table name string where other call sites pass
                # a model class; confirm _bump_table_version accepts both.
                await _bump_table_version(transport, table)

        await self._pubsub.listen(on_invalidate)
        await self._pubsub.start()

    async def _ensure_transport(self) -> None:
        """Create and connect the transport on first use."""
        if self._transport is None:
            self._transport = CashewsTransport.from_config(self._config)
            await self._transport.connect()

    async def _publish_invalidation(self, model: type[Any], pks: list[Any], action: str) -> None:
        """Publish an invalidation event for *model* if pub/sub is active."""
        if self._pubsub is None:
            return
        await self._pubsub.publish(
            {
                "table": model.__tablename__,
                "pks": pks,
                "action": action,
                "version": 1,
            }
        )

    def _schedule_invalidation(self, model: type[Any], target: Any, action: str) -> None:
        """Schedule background invalidation of *target*'s primary key.

        Does nothing when no event loop is running. *action* is not used here.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return
        pk = extract_pk_from_instance(target)
        task = asyncio.create_task(self.invalidate(model=model, pks=[pk]))
        self._pending_tasks.add(task)
        task.add_done_callback(self._pending_tasks.discard)

    def _schedule_table_bump(self, model: type[Any]) -> None:
        """Schedule a background table-version bump; no-op without a loop or transport."""
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return
        transport = self._transport
        if transport is None:
            return
        task = loop.create_task(_bump_table_version(transport, model))
        self._pending_tasks.add(task)
        task.add_done_callback(self._pending_tasks.discard)

    def _handle_select(self, execute_state: Any) -> Any:
        """Serve a SELECT from the cache when no loop runs; otherwise execute it directly."""
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return self._run_sync(resolve_cached_result(self, execute_state))
        else:
            return execute_state.invoke_statement()

    def _handle_bulk_mutation(self, execute_state: Any) -> Any:
        """Run ``handle_bulk_mutation`` synchronously for *execute_state*."""
        return self._run_sync(handle_bulk_mutation(self, execute_state))

    @staticmethod
    def _run_sync(coro: Any) -> Any:
        """Run *coro* to completion with ``asyncio.run``.

        Raises:
            RuntimeError: if an event loop is already running in this thread.
                *coro* is closed first so it does not trigger a
                "coroutine was never awaited" warning.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(coro)
        with suppress(AttributeError):
            coro.close()
        raise RuntimeError("Sync wrappers cannot be used while an event loop is already running")

    @classmethod
    def _patch_async_get(cls) -> None:
        """Replace ``AsyncSession.get`` once with a cache-aware version.

        The patched ``get`` routes lookups through the manager registered for
        the session's bind when that manager enables the ``get`` op. Calls
        using ``options``, ``populate_existing``, ``with_for_update`` or
        ``identity_token`` fall back to the original method.
        """
        if cls._async_get_patched:
            return

        cls._original_async_get = AsyncSession.get

        async def patched_get(self: AsyncSession, entity: Any, ident: Any, **kwargs: Any) -> Any:
            bind = self.sync_session.get_bind(mapper=entity)
            manager = cls._engine_registry.get(id(bind))
            if manager is None or not manager.is_enabled(entity, "get"):
                return await cls._original_async_get(self, entity, ident, **kwargs)

            unsupported = (
                kwargs.get("options"),
                kwargs.get("populate_existing"),
                kwargs.get("with_for_update"),
                kwargs.get("identity_token"),
            )
            if any(unsupported):
                return await cls._original_async_get(self, entity, ident, **kwargs)

            # NOTE(review): unlike the stock get(), this path does not consult the
            # session identity map before the cache; confirm that is intended.
            statement = manager._build_get_statement(
                entity,
                ident,
                execution_options=kwargs.get("execution_options"),
            )
            result = await manager.execute(self, statement)
            return result.scalar_one_or_none()

        setattr(AsyncSession, "get", patched_get)  # noqa: B010
        cls._async_get_patched = True

    @classmethod
    def _restore_async_get(cls) -> None:
        """Put the original ``AsyncSession.get`` back if it was patched."""
        if cls._async_get_patched and cls._original_async_get is not None:
            setattr(AsyncSession, "get", cls._original_async_get)  # noqa: B010
            cls._async_get_patched = False
            cls._original_async_get = None

    @staticmethod
    def _build_get_statement(entity: type[Any], ident: Any, execution_options: Any = None) -> Any:
        """Build ``SELECT entity WHERE <pk columns> = <ident>`` for the patched ``get``.

        *ident* may be a scalar (single-column key), a tuple or list of values
        in primary-key column order, or a dict keyed by mapped attribute name.
        These are the same forms ``Session.get`` accepts.

        Raises:
            InvalidRequestError: if the number of values does not match the
                primary key, or a dict ident is missing a key attribute.
        """
        from sqlalchemy.exc import InvalidRequestError

        mapper = entity.__mapper__
        pk_columns = list(mapper.primary_key)
        if isinstance(ident, dict):
            try:
                values = [ident[mapper.get_property_by_column(column).key] for column in pk_columns]
            except KeyError as exc:
                raise InvalidRequestError(
                    f"Incorrect names of values in identifier formed for get(); missing {exc.args[0]!r}"
                ) from exc
        elif isinstance(ident, (list, tuple)):
            values = list(ident)
        else:
            values = [ident]
        if len(values) != len(pk_columns):
            # Mirror Session.get(): never build a partial WHERE clause.
            raise InvalidRequestError(
                "Incorrect number of values in identifier formed for get(); "
                f"expected {len(pk_columns)}, got {len(values)}"
            )
        statement = select(entity).where(*[column == value for column, value in zip(pk_columns, values, strict=True)])
        if execution_options:
            statement = statement.execution_options(**dict(execution_options))
        return statement
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
"""Redis pub/sub adapter for invalidation events."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
from contextlib import suppress
|
|
9
|
+
from typing import TYPE_CHECKING, Any
|
|
10
|
+
|
|
11
|
+
from redis.asyncio import Redis
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from collections.abc import Awaitable, Callable
|
|
15
|
+
|
|
16
|
+
# Module-level logger, named after this module, for pub/sub diagnostics.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class RedisPubSub:
    """Listen for and publish invalidation events over Redis pub/sub.

    Events are JSON objects on a single channel (``sqlacache:invalidate`` by
    default). Each decoded object is passed to every registered callback.
    Malformed messages are logged and skipped without stopping the listener.
    """

    def __init__(self, redis_url: str, channel: str = "sqlacache:invalidate") -> None:
        self._redis_url = redis_url
        self._channel = channel
        self._client: Redis | None = None
        self._pubsub: Any | None = None
        self._task: asyncio.Task[None] | None = None
        self._callbacks: list[Callable[[dict[str, Any]], Awaitable[None]]] = []

    async def connect(self) -> None:
        """Create the Redis client and subscribe to the channel."""
        self._client = Redis.from_url(self._redis_url)
        self._pubsub = self._client.pubsub()
        await self._pubsub.subscribe(self._channel)

    async def listen(self, callback: Callable[[dict[str, Any]], Awaitable[None]]) -> None:
        """Register *callback* to receive every decoded event dict."""
        self._callbacks.append(callback)

    async def start(self) -> None:
        """Start the background listener task unless one already exists."""
        if self._task is None:
            self._task = asyncio.create_task(self._listen_loop())

    async def _listen_loop(self) -> None:
        """Consume channel messages and dispatch decoded events until cancelled."""
        if self._pubsub is None:
            return
        try:
            async for message in self._pubsub.listen():
                # Subscribe/unsubscribe confirmations share the stream; skip them.
                if message.get("type") != "message":
                    continue
                event = self._decode_message(message)
                if event is None:
                    continue
                await self._dispatch(event)
        except asyncio.CancelledError:
            raise
        except Exception:
            logger.exception("Pub/sub listener crashed")

    @staticmethod
    def _decode_message(message: dict[str, Any]) -> dict[str, Any] | None:
        """Return the JSON object carried by *message*, or None if it is unusable."""
        raw_data = message.get("data")
        try:
            if isinstance(raw_data, bytes):
                raw_data = raw_data.decode("utf-8")
            event = json.loads(raw_data)
        except (TypeError, ValueError):
            # ValueError covers both JSONDecodeError and UnicodeDecodeError, so a
            # single bad payload cannot terminate the listener loop.
            logger.warning("Failed to parse pub/sub message: %r", message)
            return None
        if not isinstance(event, dict):
            logger.warning("Ignoring non-object pub/sub message: %r", message)
            return None
        return event

    async def _dispatch(self, event: dict[str, Any]) -> None:
        """Await each callback with *event*; a failing callback is logged, not fatal."""
        for callback in list(self._callbacks):
            try:
                await callback(event)
            except Exception:
                logger.exception("Error in pub/sub callback")

    async def publish(self, event: dict[str, Any]) -> int:
        """Publish *event* as sorted-key JSON; returns Redis's ``PUBLISH`` reply.

        Raises:
            RuntimeError: if :meth:`connect` has not been called.
        """
        if self._client is None:
            raise RuntimeError("RedisPubSub is not connected")
        message = json.dumps(event, default=str, sort_keys=True)
        return int(await self._client.publish(self._channel, message))

    async def stop(self) -> None:
        """Cancel the listener task, then unsubscribe and close the pubsub connection.

        The connection is closed even if ``unsubscribe`` raises. The error
        still propagates to the caller.
        """
        if self._task is not None:
            self._task.cancel()
            with suppress(asyncio.CancelledError):
                await self._task
            self._task = None
        pubsub = self._pubsub
        if pubsub is not None:
            try:
                await pubsub.unsubscribe(self._channel)
            finally:
                self._pubsub = None
                await pubsub.aclose()

    async def disconnect(self) -> None:
        """Stop listening and close the Redis client, even if :meth:`stop` fails."""
        try:
            await self.stop()
        finally:
            client = self._client
            if client is not None:
                self._client = None
                await client.aclose()
|
sqlacache/py.typed
ADDED
|
File without changes
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""Model-aware JSON serialization helpers."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from sqlalchemy import inspect
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ModelJSONSerializer:
    """JSON codec for ORM instances that keeps only mapped column attributes.

    Relationships are not included. Values that ``json`` cannot encode
    natively are converted with ``str``. Decoding yields plain dicts and
    lists, not model instances.
    """

    def serialize(self, obj: Any) -> str:
        """Encode one ORM instance, or a list of them, as a key-sorted JSON string."""
        payload: Any
        if isinstance(obj, list):
            payload = list(map(self._instance_to_dict, obj))
        else:
            payload = self._instance_to_dict(obj)
        return json.dumps(payload, default=str, sort_keys=True)

    def deserialize(self, payload: str) -> Any:
        """Decode *payload* back into JSON primitives (dicts/lists)."""
        return json.loads(payload)

    @staticmethod
    def _instance_to_dict(instance: Any) -> dict[str, Any]:
        """Map each column attribute key of *instance*'s mapper to its current value."""
        column_attrs = inspect(instance).mapper.column_attrs
        row: dict[str, Any] = {}
        for column_attr in column_attrs:
            row[column_attr.key] = getattr(instance, column_attr.key)
        return row
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
"""Transport abstractions."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING, Any, Protocol
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from collections.abc import Sequence
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class CacheTransport(Protocol):
    """Structural interface that every cache storage backend implements.

    Every operation is a coroutine, so implementations may perform I/O.
    """

    async def connect(self) -> None:
        """Prepare the backend for use."""
        ...

    async def disconnect(self) -> None:
        """Release whatever resources the backend holds."""
        ...

    async def get(self, key: str) -> Any | None:
        """Look up *key*; ``None`` signals a cache miss."""
        ...

    async def set(
        self,
        key: str,
        value: Any,
        expire: int,
        tags: Sequence[str] | None = None,
    ) -> None:
        """Store *value* under *key* with expiry *expire*, optionally linked to *tags*."""
        ...

    async def delete(self, *keys: str) -> None:
        """Remove every key in *keys*."""
        ...

    async def delete_tags(self, *tags: str) -> None:
        """Drop all entries associated with any of *tags*."""
        ...

    async def clear(self) -> None:
        """Remove every cached entry."""
        ...

    async def incr(self, key: str) -> int:
        """Atomically bump the integer counter at *key* and return its new value."""
        ...

    async def is_available(self) -> bool:
        """Report whether the backend can currently serve requests."""
        ...
|
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
"""Cashews-based transport implementation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from collections.abc import Mapping, Sequence
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
from cashews import Cache
|
|
10
|
+
|
|
11
|
+
from sqlacache.exceptions import ConfigError, TransportError
|
|
12
|
+
|
|
13
|
+
# Module-level logger, named after this module, for transport warnings.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class CashewsTransport:
    """Async cache transport that delegates every operation to ``cashews.Cache``.

    When the ``suppress`` option is truthy, backend failures are logged as
    warnings and a neutral value is returned. Otherwise they are re-raised
    as :class:`TransportError`.
    """

    def __init__(self, url: str, **kwargs: Any) -> None:
        self._url = url
        options = dict(kwargs)
        self._kwargs = options
        # ``suppress`` is read here and also passed through to ``Cache.setup``.
        self._suppress = bool(options.get("suppress", False))
        self._cache = Cache()
        self._connected = False

    @classmethod
    def from_config(cls, config: Mapping[str, Any]) -> CashewsTransport:
        """Build a transport instance from normalized cache config.

        ``backend`` may be a URL string or a mapping with a ``url`` key
        (default ``"mem://"``) plus extra ``Cache.setup`` options.
        ``pickle_type`` defaults to the config's ``serializer``
        (``"sqlalchemy"`` if unset).

        Raises:
            ConfigError: for an empty/non-string mapping URL or an unsupported
                backend type.
        """
        backend = config.get("backend", {})
        options: dict[str, Any]
        if isinstance(backend, str):
            url, options = backend, {}
        elif isinstance(backend, Mapping):
            url = backend.get("url", "mem://")
            if not (isinstance(url, str) and url):
                raise ConfigError("Backend mapping must define a non-empty 'url' string")
            options = {name: value for name, value in backend.items() if name != "url"}
        else:
            raise ConfigError("Backend must be configured as a URL string or mapping")

        options.setdefault("pickle_type", config.get("serializer", "sqlalchemy"))
        return cls(url=url, **options)

    def _on_failure(self, exc: Exception, log_message: str, log_subject: Any, error_message: str) -> None:
        """Log *exc* when suppressing errors; otherwise raise TransportError from it."""
        if not self._suppress:
            raise TransportError(error_message) from exc
        logger.warning(log_message, log_subject, exc)

    async def connect(self) -> None:
        """Configure the underlying cache; marks the transport connected on success."""
        try:
            self._cache.setup(self._url, **self._kwargs)
        except Exception as exc:
            self._on_failure(exc, "Cache connect failed for %s: %s", self._url, f"Cache connect failed for {self._url!r}")
            return
        self._connected = True

    async def disconnect(self) -> None:
        """Close the underlying cache; always marks the transport disconnected."""
        try:
            await self._cache.close()
        except Exception as exc:
            self._on_failure(exc, "Cache disconnect failed for %s: %s", self._url, "Cache disconnect failed")
        finally:
            self._connected = False

    async def get(self, key: str) -> Any | None:
        """Fetch *key*; returns None on a suppressed failure."""
        try:
            return await self._cache.get(key)
        except Exception as exc:
            self._on_failure(exc, "Cache get failed for key %s: %s", key, f"Cache get failed for key {key!r}")
            return None

    async def set(
        self,
        key: str,
        value: Any,
        expire: int,
        tags: Sequence[str] | None = None,
    ) -> None:
        """Store *value* under *key* with expiry *expire* and optional *tags*."""
        try:
            await self._cache.set(key, value, expire=expire, tags=tags or ())
        except Exception as exc:
            self._on_failure(exc, "Cache set failed for key %s: %s", key, f"Cache set failed for key {key!r}")

    async def delete(self, *keys: str) -> None:
        """Delete *keys* one at a time; stops at the first failure."""
        try:
            for single_key in keys:
                await self._cache.delete(single_key)
        except Exception as exc:
            self._on_failure(exc, "Cache delete failed for keys %s: %s", keys, f"Cache delete failed for keys {keys!r}")

    async def delete_tags(self, *tags: str) -> None:
        """Invalidate every entry associated with *tags*."""
        try:
            await self._cache.delete_tags(*tags)
        except Exception as exc:
            self._on_failure(
                exc, "Cache delete_tags failed for tags %s: %s", tags, f"Cache delete_tags failed for tags {tags!r}"
            )

    async def clear(self) -> None:
        """Remove every entry from the cache."""
        try:
            await self._cache.clear()
        except Exception as exc:
            self._on_failure(exc, "Cache clear failed for %s: %s", self._url, "Cache clear failed")

    async def incr(self, key: str) -> int:
        """Increment counter *key* and return it as int; returns 0 on a suppressed failure."""
        try:
            return int(await self._cache.incr(key))
        except Exception as exc:
            self._on_failure(exc, "Cache incr failed for key %s: %s", key, f"Cache incr failed for key {key!r}")
            return 0

    async def is_available(self) -> bool:
        """Return True when connected and a probe ``get`` succeeds."""
        if not self._connected:
            return False
        try:
            await self._cache.get("__sqlacache_healthcheck__")
        except Exception:
            return False
        return True
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Utility helpers for sqlacache."""
|