sqlphilosophy 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sqlphilosophy/VERSION +1 -0
- sqlphilosophy/__init__.py +3 -0
- sqlphilosophy/aio/__init__.py +3 -0
- sqlphilosophy/aio/protocols.py +26 -0
- sqlphilosophy/aio/query.py +396 -0
- sqlphilosophy/aio/repository.py +400 -0
- sqlphilosophy/audit/__init__.py +3 -0
- sqlphilosophy/audit/context.py +37 -0
- sqlphilosophy/audit/fields.py +24 -0
- sqlphilosophy/audit/listener.py +99 -0
- sqlphilosophy/audit/model.py +59 -0
- sqlphilosophy/py.typed +0 -0
- sqlphilosophy/sorting.py +97 -0
- sqlphilosophy/sql.py +532 -0
- sqlphilosophy/sync/__init__.py +3 -0
- sqlphilosophy/sync/protocols.py +26 -0
- sqlphilosophy/sync/query.py +392 -0
- sqlphilosophy/sync/repository.py +360 -0
- sqlphilosophy/types.py +61 -0
- sqlphilosophy-0.1.0.dist-info/METADATA +134 -0
- sqlphilosophy-0.1.0.dist-info/RECORD +24 -0
- sqlphilosophy-0.1.0.dist-info/WHEEL +5 -0
- sqlphilosophy-0.1.0.dist-info/licenses/LICENSE +21 -0
- sqlphilosophy-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,400 @@
|
|
|
1
|
+
"""Generic async session-bound repository for ORM CRUD."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
from collections.abc import Sequence
|
|
5
|
+
from typing import Any
|
|
6
|
+
from typing import cast
|
|
7
|
+
from sqlalchemy import delete
|
|
8
|
+
from sqlalchemy import func
|
|
9
|
+
from sqlalchemy import inspect as sa_inspect
|
|
10
|
+
from sqlalchemy import select
|
|
11
|
+
from sqlalchemy import update
|
|
12
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
13
|
+
from sqlalchemy.orm import DeclarativeBase
|
|
14
|
+
from sqlalchemy.orm.interfaces import LoaderOption
|
|
15
|
+
from sqlphilosophy.aio.protocols import AsyncRepositoryFactory
|
|
16
|
+
from sqlphilosophy.aio.query import AsyncSqlAlchemyStatementBuilder
|
|
17
|
+
from sqlphilosophy.aio.query import AsyncStatementQueryBuilder
|
|
18
|
+
from sqlphilosophy.audit.model import AuditMixin
|
|
19
|
+
from sqlphilosophy.sorting import ListQuery
|
|
20
|
+
from sqlphilosophy.sorting import SortConfig
|
|
21
|
+
from sqlphilosophy.sql import rows_mapping
|
|
22
|
+
from sqlphilosophy.types import IdList
|
|
23
|
+
from sqlphilosophy.types import PrimaryKey
|
|
24
|
+
from sqlphilosophy.types import RowMapping
|
|
25
|
+
from sqlphilosophy.types import RowValue
|
|
26
|
+
from sqlphilosophy.types import SqlBindParams
|
|
27
|
+
from sqlphilosophy.types import SqlSelect
|
|
28
|
+
|
|
29
|
+
# Eager-loading options handed to ``Select.options(...)`` by the repository.
LoadRelations = Sequence[LoaderOption]
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class AsyncBaseRepository[T: DeclarativeBase]:
|
|
33
|
+
"""Async session-scoped CRUD helpers for a single mapped model."""
|
|
34
|
+
|
|
35
|
+
def __init__(
|
|
36
|
+
self,
|
|
37
|
+
model: type[T],
|
|
38
|
+
session: AsyncSession,
|
|
39
|
+
factory: AsyncRepositoryFactory | None = None,
|
|
40
|
+
) -> None:
|
|
41
|
+
self.model = model
|
|
42
|
+
self.session = session
|
|
43
|
+
self._factory = factory
|
|
44
|
+
pk_cols = self.inspect_model(model).primary_key
|
|
45
|
+
if len(pk_cols) != 1:
|
|
46
|
+
raise TypeError(f"{model.__name__} must have a single-column primary key")
|
|
47
|
+
self._pk_column = pk_cols[0]
|
|
48
|
+
|
|
49
|
+
@classmethod
|
|
50
|
+
def inspect_model(cls, model: type[DeclarativeBase]) -> Any:
|
|
51
|
+
"""Return SQLAlchemy ORM mapper inspection for ``model``."""
|
|
52
|
+
return sa_inspect(model)
|
|
53
|
+
|
|
54
|
+
def inspect(self) -> Any:
|
|
55
|
+
"""Return SQLAlchemy ORM mapper inspection for this repository's model."""
|
|
56
|
+
return self.inspect_model(self.model)
|
|
57
|
+
|
|
58
|
+
async def list_table_names(self) -> frozenset[str]:
|
|
59
|
+
"""Return visible table names on the session connection."""
|
|
60
|
+
connection = await self.session.connection()
|
|
61
|
+
|
|
62
|
+
def _names(sync_conn: object) -> frozenset[str]:
|
|
63
|
+
return frozenset(sa_inspect(sync_conn).get_table_names())
|
|
64
|
+
|
|
65
|
+
return await connection.run_sync(_names)
|
|
66
|
+
|
|
67
|
+
async def has_table(self, table_name: str) -> bool:
|
|
68
|
+
"""True when ``table_name`` exists on the session connection."""
|
|
69
|
+
return table_name in await self.list_table_names()
|
|
70
|
+
|
|
71
|
+
def _apply_load_relations(self, stmt: Any, load_relations: LoadRelations | None) -> Any:
|
|
72
|
+
if load_relations:
|
|
73
|
+
return stmt.options(*load_relations)
|
|
74
|
+
return stmt
|
|
75
|
+
|
|
76
|
+
async def _scalar_result(self, stmt: Any, *, unique: bool = False) -> Any:
|
|
77
|
+
result = await self.session.scalars(stmt)
|
|
78
|
+
if unique:
|
|
79
|
+
return result.unique()
|
|
80
|
+
return result
|
|
81
|
+
|
|
82
|
+
async def fetch_statement_mappings(
|
|
83
|
+
self, stmt: Any, params: RowMapping | None = None
|
|
84
|
+
) -> list[RowMapping]:
|
|
85
|
+
"""Execute ``stmt`` and return all rows as mappings."""
|
|
86
|
+
result = await self.session.execute(stmt, params or {})
|
|
87
|
+
mapped = result.mappings()
|
|
88
|
+
rows = mapped.all() if hasattr(mapped, "all") else mapped
|
|
89
|
+
return rows_mapping(rows)
|
|
90
|
+
|
|
91
|
+
async def scalar_count(self, stmt: SqlSelect, params: SqlBindParams | None = None) -> int:
|
|
92
|
+
"""Execute a scalar count/select statement and return ``int``."""
|
|
93
|
+
result = await self.session.execute(stmt, params or {})
|
|
94
|
+
return int(result.scalar_one())
|
|
95
|
+
|
|
96
|
+
async def iter_mappings(self, stmt: SqlSelect, params: SqlBindParams | None = None):
|
|
97
|
+
"""Yield each result row as a plain ``dict``."""
|
|
98
|
+
result = await self.session.execute(stmt, params or {})
|
|
99
|
+
for row in result.mappings():
|
|
100
|
+
yield dict(row)
|
|
101
|
+
|
|
102
|
+
async def fetch_mapping_first(
|
|
103
|
+
self, stmt: SqlSelect, params: SqlBindParams | None = None
|
|
104
|
+
) -> RowMapping | None:
|
|
105
|
+
"""Execute ``stmt`` and return the first row as a mapping, or ``None``."""
|
|
106
|
+
result = await self.session.execute(stmt, params or {})
|
|
107
|
+
row = result.mappings().first()
|
|
108
|
+
return dict(row) if row is not None else None
|
|
109
|
+
|
|
110
|
+
async def fetch_mapping_one(
|
|
111
|
+
self, stmt: SqlSelect, params: SqlBindParams | None = None
|
|
112
|
+
) -> RowMapping:
|
|
113
|
+
"""Execute ``stmt`` and return exactly one row as a mapping."""
|
|
114
|
+
result = await self.session.execute(stmt, params or {})
|
|
115
|
+
return dict(result.mappings().one())
|
|
116
|
+
|
|
117
|
+
async def fetch_mappings_page(
|
|
118
|
+
self,
|
|
119
|
+
stmt: Any,
|
|
120
|
+
*,
|
|
121
|
+
limit: int,
|
|
122
|
+
offset: int,
|
|
123
|
+
params: RowMapping | None = None,
|
|
124
|
+
) -> list[RowMapping]:
|
|
125
|
+
"""Execute ``stmt`` with limit/offset; return normalized row mappings."""
|
|
126
|
+
if limit < 0:
|
|
127
|
+
raise ValueError("limit must be >= 0")
|
|
128
|
+
if offset < 0:
|
|
129
|
+
raise ValueError("offset must be >= 0")
|
|
130
|
+
paged = stmt.limit(limit).offset(offset)
|
|
131
|
+
return await self.fetch_statement_mappings(paged, params)
|
|
132
|
+
|
|
133
|
+
async def fetch_sorted_mappings(
|
|
134
|
+
self,
|
|
135
|
+
stmt: Any,
|
|
136
|
+
*,
|
|
137
|
+
list_query: ListQuery,
|
|
138
|
+
params: RowMapping | None = None,
|
|
139
|
+
sort: SortConfig | None = None,
|
|
140
|
+
) -> list[RowMapping]:
|
|
141
|
+
"""Apply optional sort, then return one page of row mappings."""
|
|
142
|
+
if sort is not None:
|
|
143
|
+
stmt = stmt.order_by(*sort.order_clauses(list_query.order_by))
|
|
144
|
+
return await self.fetch_mappings_page(
|
|
145
|
+
stmt,
|
|
146
|
+
limit=list_query.limit,
|
|
147
|
+
offset=list_query.offset,
|
|
148
|
+
params=params,
|
|
149
|
+
)
|
|
150
|
+
|
|
151
|
+
async def get_by_id(
|
|
152
|
+
self, obj_id: PrimaryKey, load_relations: LoadRelations | None = None
|
|
153
|
+
) -> T | None:
|
|
154
|
+
"""Fetch a single record by primary key with optional eager loading."""
|
|
155
|
+
stmt = select(self.model).where(self._pk_column == obj_id)
|
|
156
|
+
stmt = self._apply_load_relations(stmt, load_relations)
|
|
157
|
+
result = await self.session.scalars(stmt)
|
|
158
|
+
return result.first()
|
|
159
|
+
|
|
160
|
+
async def exists(self, obj_id: PrimaryKey) -> bool:
|
|
161
|
+
"""True when a row exists for the primary key."""
|
|
162
|
+
return await self.get_by_id(obj_id) is not None
|
|
163
|
+
|
|
164
|
+
async def exists_where(self, **filters: object) -> bool:
|
|
165
|
+
"""True when at least one row matches optional equality filters."""
|
|
166
|
+
return await self.count(**filters) > 0
|
|
167
|
+
|
|
168
|
+
async def count(self, **filters: object) -> int:
|
|
169
|
+
"""Count rows matching optional equality filters."""
|
|
170
|
+
stmt = select(func.count()).select_from(self.model)
|
|
171
|
+
if filters:
|
|
172
|
+
stmt = stmt.filter_by(**filters)
|
|
173
|
+
result = await self.session.scalar(stmt)
|
|
174
|
+
return int(result or 0)
|
|
175
|
+
|
|
176
|
+
async def first(
|
|
177
|
+
self, load_relations: LoadRelations | None = None, **filters: object
|
|
178
|
+
) -> T | None:
|
|
179
|
+
"""Return the first row matching filters, with optional eager loading."""
|
|
180
|
+
stmt = select(self.model).filter_by(**filters).limit(1)
|
|
181
|
+
stmt = self._apply_load_relations(stmt, load_relations)
|
|
182
|
+
result = await self.session.scalars(stmt)
|
|
183
|
+
return result.first()
|
|
184
|
+
|
|
185
|
+
async def get(self, obj_id: PrimaryKey, load_relations: LoadRelations | None = None) -> T:
|
|
186
|
+
"""Fetch a single record by primary key; raise if missing."""
|
|
187
|
+
obj = await self.get_by_id(obj_id, load_relations=load_relations)
|
|
188
|
+
if obj is None:
|
|
189
|
+
raise LookupError(f"{self.model.__name__} matching id={obj_id!r} not found")
|
|
190
|
+
return obj
|
|
191
|
+
|
|
192
|
+
async def get_many(
|
|
193
|
+
self, ids: Sequence[PrimaryKey], load_relations: LoadRelations | None = None
|
|
194
|
+
) -> Sequence[T]:
|
|
195
|
+
"""Fetch multiple records by primary key."""
|
|
196
|
+
if not ids:
|
|
197
|
+
return []
|
|
198
|
+
stmt = select(self.model).where(self._pk_column.in_(ids))
|
|
199
|
+
stmt = self._apply_load_relations(stmt, load_relations)
|
|
200
|
+
result = await self._scalar_result(stmt, unique=load_relations is not None)
|
|
201
|
+
return result.all()
|
|
202
|
+
|
|
203
|
+
async def filter(
|
|
204
|
+
self,
|
|
205
|
+
*,
|
|
206
|
+
page: int = 1,
|
|
207
|
+
limit: int | None = None,
|
|
208
|
+
load_relations: LoadRelations | None = None,
|
|
209
|
+
**filters: object,
|
|
210
|
+
) -> Sequence[T]:
|
|
211
|
+
"""Return rows matching optional equality filters, optionally paginated."""
|
|
212
|
+
if page < 1:
|
|
213
|
+
raise ValueError("page must be >= 1")
|
|
214
|
+
if limit is not None and limit < 1:
|
|
215
|
+
raise ValueError("limit must be >= 1")
|
|
216
|
+
stmt = select(self.model).filter_by(**filters).order_by(self._pk_column)
|
|
217
|
+
if limit is not None:
|
|
218
|
+
stmt = stmt.limit(limit).offset((page - 1) * limit)
|
|
219
|
+
stmt = self._apply_load_relations(stmt, load_relations)
|
|
220
|
+
result = await self._scalar_result(stmt, unique=load_relations is not None)
|
|
221
|
+
return result.all()
|
|
222
|
+
|
|
223
|
+
async def get_all(
|
|
224
|
+
self,
|
|
225
|
+
*,
|
|
226
|
+
page: int = 1,
|
|
227
|
+
limit: int | None = None,
|
|
228
|
+
load_relations: LoadRelations | None = None,
|
|
229
|
+
) -> Sequence[T]:
|
|
230
|
+
"""Fetch records for this model type, optionally paginated by ``page`` and ``limit``."""
|
|
231
|
+
if page < 1:
|
|
232
|
+
raise ValueError("page must be >= 1")
|
|
233
|
+
if limit is not None and limit < 1:
|
|
234
|
+
raise ValueError("limit must be >= 1")
|
|
235
|
+
statement = select(self.model).order_by(self._pk_column)
|
|
236
|
+
if limit is not None:
|
|
237
|
+
statement = statement.limit(limit).offset((page - 1) * limit)
|
|
238
|
+
statement = self._apply_load_relations(statement, load_relations)
|
|
239
|
+
result = await self._scalar_result(statement, unique=load_relations is not None)
|
|
240
|
+
return result.all()
|
|
241
|
+
|
|
242
|
+
async def get_with_join(
|
|
243
|
+
self,
|
|
244
|
+
target_model: type[Any],
|
|
245
|
+
*filter_expressions: Any,
|
|
246
|
+
join_on: Any = None,
|
|
247
|
+
) -> Sequence[tuple[T, Any]]:
|
|
248
|
+
"""Explicit INNER JOIN returning ``(base_row, target_row)`` tuples."""
|
|
249
|
+
stmt = select(self.model, target_model)
|
|
250
|
+
if join_on is not None:
|
|
251
|
+
stmt = stmt.join(target_model, join_on)
|
|
252
|
+
else:
|
|
253
|
+
stmt = stmt.join(target_model) # pragma: no cover
|
|
254
|
+
if filter_expressions:
|
|
255
|
+
stmt = stmt.where(*filter_expressions)
|
|
256
|
+
result = await self.session.execute(stmt)
|
|
257
|
+
return result.all()
|
|
258
|
+
|
|
259
|
+
async def create(self, **fields: object) -> T:
|
|
260
|
+
"""Construct, stage, and flush a new instance."""
|
|
261
|
+
return await self.add(self.model(**fields))
|
|
262
|
+
|
|
263
|
+
async def get_or_create(
|
|
264
|
+
self,
|
|
265
|
+
*,
|
|
266
|
+
defaults: RowMapping | None = None,
|
|
267
|
+
**lookup: object,
|
|
268
|
+
) -> tuple[T, bool]:
|
|
269
|
+
"""Return ``(instance, created)`` for equality ``lookup`` filters."""
|
|
270
|
+
existing = await self.first(**lookup)
|
|
271
|
+
if existing is not None:
|
|
272
|
+
return existing, False
|
|
273
|
+
payload: RowMapping = {**(defaults or {}), **lookup}
|
|
274
|
+
return await self.create(**payload), True
|
|
275
|
+
|
|
276
|
+
async def add(self, obj: T) -> T:
|
|
277
|
+
"""Stage a new instance; caller commits in the orchestration layer."""
|
|
278
|
+
self.session.add(obj)
|
|
279
|
+
await self.session.flush()
|
|
280
|
+
return obj
|
|
281
|
+
|
|
282
|
+
async def update_partial(
|
|
283
|
+
self,
|
|
284
|
+
obj_id: PrimaryKey,
|
|
285
|
+
fields: RowMapping,
|
|
286
|
+
writable: frozenset[str],
|
|
287
|
+
*,
|
|
288
|
+
touch_updated_on: bool = False,
|
|
289
|
+
) -> int:
|
|
290
|
+
"""Apply a partial update; returns affected row count (0 if none)."""
|
|
291
|
+
if issubclass(self.model, AuditMixin):
|
|
292
|
+
audit_updates = {k: v for k, v in fields.items() if k in writable}
|
|
293
|
+
if not audit_updates:
|
|
294
|
+
return 0
|
|
295
|
+
row = await self.session.get(self.model, obj_id)
|
|
296
|
+
if row is None:
|
|
297
|
+
return 0
|
|
298
|
+
for key, value in audit_updates.items():
|
|
299
|
+
setattr(row, key, value)
|
|
300
|
+
await self.session.flush()
|
|
301
|
+
return 1
|
|
302
|
+
core_updates: RowMapping = {k: v for k, v in fields.items() if k in writable}
|
|
303
|
+
if not core_updates:
|
|
304
|
+
return 0
|
|
305
|
+
if touch_updated_on:
|
|
306
|
+
core_updates = cast(
|
|
307
|
+
RowMapping,
|
|
308
|
+
{**dict(core_updates), "updated_on": cast(RowValue, func.now())},
|
|
309
|
+
)
|
|
310
|
+
pk_col = self._pk_column
|
|
311
|
+
stmt = update(self.model).where(pk_col == obj_id).values(**core_updates)
|
|
312
|
+
result = await self.session.execute(stmt)
|
|
313
|
+
return int(result.rowcount or 0)
|
|
314
|
+
|
|
315
|
+
async def update_where(
|
|
316
|
+
self,
|
|
317
|
+
*,
|
|
318
|
+
criteria: Sequence[object],
|
|
319
|
+
values: RowMapping,
|
|
320
|
+
params: SqlBindParams | None = None,
|
|
321
|
+
) -> int:
|
|
322
|
+
"""Bulk UPDATE rows matching ``criteria``; returns affected row count."""
|
|
323
|
+
if not values:
|
|
324
|
+
return 0
|
|
325
|
+
stmt = update(self.model).where(*criteria).values(**values)
|
|
326
|
+
result = await self.session.execute(stmt, params or {})
|
|
327
|
+
return int(result.rowcount or 0)
|
|
328
|
+
|
|
329
|
+
async def delete_where(
|
|
330
|
+
self,
|
|
331
|
+
*,
|
|
332
|
+
criteria: Sequence[object],
|
|
333
|
+
params: SqlBindParams | None = None,
|
|
334
|
+
) -> int:
|
|
335
|
+
"""Delete rows matching ``criteria`` via PK lookup + ``delete_many``."""
|
|
336
|
+
if not criteria:
|
|
337
|
+
return 0
|
|
338
|
+
pk_key = self._pk_column.key
|
|
339
|
+
builder = self.statement().select_columns(self._pk_column).where(*criteria)
|
|
340
|
+
if params:
|
|
341
|
+
builder = builder.with_params(params)
|
|
342
|
+
rows = await builder.mappings().all()
|
|
343
|
+
ids = [row[pk_key] for row in rows]
|
|
344
|
+
return await self.delete_many(ids)
|
|
345
|
+
|
|
346
|
+
async def remove(self, obj_id: PrimaryKey) -> bool:
|
|
347
|
+
"""Delete a record by primary key."""
|
|
348
|
+
statement = delete(self.model).where(self._pk_column == obj_id)
|
|
349
|
+
result = await self.session.execute(statement)
|
|
350
|
+
return bool(result.rowcount)
|
|
351
|
+
|
|
352
|
+
async def delete_many(self, ids: IdList) -> int:
|
|
353
|
+
"""Delete multiple records by primary key."""
|
|
354
|
+
if not ids:
|
|
355
|
+
return 0
|
|
356
|
+
stmt = delete(self.model).where(self._pk_column.in_(ids))
|
|
357
|
+
result = await self.session.execute(stmt)
|
|
358
|
+
return int(result.rowcount or 0)
|
|
359
|
+
|
|
360
|
+
async def delete_all(self) -> int:
|
|
361
|
+
"""Delete every row for this model. Dev/ops only — prefer ``delete_where`` in app code."""
|
|
362
|
+
result = await self.session.execute(delete(self.model))
|
|
363
|
+
return int(result.rowcount or 0)
|
|
364
|
+
|
|
365
|
+
async def batched_purge_ids(
|
|
366
|
+
self,
|
|
367
|
+
*,
|
|
368
|
+
criteria: list[object],
|
|
369
|
+
batch_size: int,
|
|
370
|
+
) -> int:
|
|
371
|
+
"""Delete rows matching ``criteria`` in ``batch_size`` chunks, committing each batch."""
|
|
372
|
+
pk_key = self._pk_column.key
|
|
373
|
+
total = 0
|
|
374
|
+
while True:
|
|
375
|
+
rows = await (
|
|
376
|
+
self.statement()
|
|
377
|
+
.select_columns(self._pk_column)
|
|
378
|
+
.where(*criteria)
|
|
379
|
+
.limit(batch_size)
|
|
380
|
+
.mappings()
|
|
381
|
+
.all()
|
|
382
|
+
)
|
|
383
|
+
ids = [row[pk_key] for row in rows]
|
|
384
|
+
if not ids:
|
|
385
|
+
break
|
|
386
|
+
total += await self.delete_many(ids)
|
|
387
|
+
await self.session.commit()
|
|
388
|
+
return total
|
|
389
|
+
|
|
390
|
+
def statement(self) -> AsyncStatementQueryBuilder[T]:
|
|
391
|
+
"""Return a fluent statement builder for reads on this model (default read path)."""
|
|
392
|
+
if self._factory is not None:
|
|
393
|
+
return self._factory.create_statement(self.model)
|
|
394
|
+
return AsyncSqlAlchemyStatementBuilder(self.session, self.model)
|
|
395
|
+
|
|
396
|
+
def for_repo[R](self, repo_class: type[R]) -> R:
|
|
397
|
+
"""Return a typed entity repository sharing this session and factory."""
|
|
398
|
+
if self._factory is None:
|
|
399
|
+
raise RuntimeError("for_repo() requires an AsyncRepositoryFactory")
|
|
400
|
+
return self._factory.get_repository(repo_class)
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
"""Request-scoped audit actor context."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
from collections.abc import Iterator
|
|
5
|
+
from contextlib import contextmanager
|
|
6
|
+
from contextvars import ContextVar
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@dataclass(frozen=True)
class AuditContext:
    """Immutable record of the actor behind the current unit of work.

    Attributes:
        actor_id: Identifier of the acting principal, or ``None`` when no
            actor is known.
    """

    actor_id: int | str | None = None
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Context-local slot for the active AuditContext; ``None`` until something sets it.
_audit_context: ContextVar[AuditContext | None] = ContextVar(
    "audit_context",
    default=None,
)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def get_audit_context() -> AuditContext | None:
    """Return the audit context active in the current context, or ``None``."""
    active = _audit_context.get()
    return active
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def get_audit_actor_id() -> int | str | None:
    """Return the actor id of the active audit context, or ``None`` if unset."""
    current = get_audit_context()
    if current is None:
        return None
    return current.actor_id
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def set_audit_context(ctx: AuditContext | None) -> None:
    """Replace the active audit context for the current context.

    The reset token from ``ContextVar.set`` is discarded, so the previous
    value is not restored automatically; use ``audit_context`` for scoped use.
    """
    _audit_context.set(ctx)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@contextmanager
def audit_context(actor_id: int | str | None) -> Iterator[None]:
    """Activate an ``AuditContext`` for ``actor_id`` inside a ``with`` block.

    The previously active context is restored on exit, including when the
    body raises.
    """
    reset_token = _audit_context.set(AuditContext(actor_id=actor_id))
    try:
        yield
    finally:
        _audit_context.reset(reset_token)
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
"""Logical audit column names for listener stamping."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass(frozen=True)
class AuditColumns:
    """Attribute names the audit listener stamps on audited models.

    Each field holds the mapped attribute name used for one audit column.
    """

    created: str = "created_on"
    updated: str = "updated_on"
    created_by: str = "created_by_id"
    updated_by: str = "updated_by_id"
    deleted: str = "deleted_on"
    deleted_by: str = "deleted_by_id"

    @classmethod
    def for_model(cls, model_type: type) -> AuditColumns:
        """Return the audit column names for ``model_type``.

        Every model currently receives the defaults; ``model_type`` is unused
        here (presumably an override point for subclasses).
        """
        defaults = cls()
        return defaults
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def is_audit_model(instance: object) -> bool:
    """Return True when ``instance`` is an ``AuditMixin`` instance."""
    # Imported at call time, presumably to keep module import order loose.
    from sqlphilosophy.audit.model import AuditMixin as _AuditMixin

    return isinstance(instance, _AuditMixin)
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
"""SQLAlchemy audit listeners gated on :class:`AuditMixin`."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
from abc import ABC
|
|
5
|
+
from datetime import datetime
|
|
6
|
+
from datetime import timezone
|
|
7
|
+
from typing import Any
|
|
8
|
+
from sqlalchemy import event
|
|
9
|
+
from sqlalchemy.orm import Mapper
|
|
10
|
+
from sqlphilosophy.audit.context import get_audit_context
|
|
11
|
+
from sqlphilosophy.audit.fields import AuditColumns
|
|
12
|
+
from sqlphilosophy.audit.fields import is_audit_model
|
|
13
|
+
from sqlphilosophy.audit.model import AuditMixin
|
|
14
|
+
|
|
15
|
+
# Module-wide guard so AuditListener.attach() registers mapper events only once.
_ATTACHED: bool = False
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class AuditListener(ABC):
    """Stamps audit timestamps and actors on ``AuditMixin`` instances.

    :meth:`attach` wires ``before_insert``/``before_update`` mapper events on
    ``AuditMixin`` (propagated to subclasses) to this listener.
    """

    def now(self) -> datetime:
        """Return the current time as a timezone-aware UTC ``datetime``."""
        return datetime.now(tz=timezone.utc)

    def _actor(self) -> int | str | None:
        """Return the actor id from the active audit context, if any."""
        ctx = get_audit_context()
        if ctx is None:
            return None
        return ctx.actor_id

    def _has_attr(self, target: object, name: str) -> bool:
        """True when the class of ``target`` declares attribute ``name``."""
        # Checked on the class, so instance-only attributes are ignored.
        return hasattr(type(target), name)

    def _set_if_empty(self, target: object, name: str, value: object) -> None:
        """Assign ``value`` unless ``name`` already holds a value other than None/""."""
        if self._has_attr(target, name):
            existing = getattr(target, name)
            if existing is None or existing == "":
                setattr(target, name, value)

    def _set(self, target: object, name: str, value: object) -> None:
        """Assign ``value`` to ``name`` when the target's class declares it."""
        if not self._has_attr(target, name):
            return
        setattr(target, name, value)

    def stamp_on_insert(self, target: AuditMixin) -> None:
        """Fill still-empty created/updated timestamps and, if known, actors."""
        cols = AuditColumns.for_model(type(target))
        stamped_at = self.now()
        for attr in (cols.created, cols.updated):
            self._set_if_empty(target, attr, stamped_at)
        actor = self._actor()
        if actor is None:
            return
        for attr in (cols.created_by, cols.updated_by):
            self._set_if_empty(target, attr, actor)

    def stamp_on_update(self, target: AuditMixin) -> None:
        """Refresh the update timestamp and, when known, the updating actor."""
        cols = AuditColumns.for_model(type(target))
        if self._has_attr(target, cols.updated):
            setattr(target, cols.updated, self.now())
        actor = self._actor()
        if actor is not None:
            self._set(target, cols.updated_by, actor)

    def stamp_on_soft_delete(self, target: AuditMixin, *, actor: int | str | None = None) -> None:
        """Stamp the soft-delete timestamp and actor on ``target``.

        An explicit ``actor`` wins over the context actor; with neither, only
        the timestamp is set.
        """
        cols = AuditColumns.for_model(type(target))
        self._set(target, cols.deleted, self.now())
        who = self._actor() if actor is None else actor
        if who is not None:
            self._set(target, cols.deleted_by, who)

    def attach(self) -> None:
        """Register insert/update mapper listeners on ``AuditMixin`` once.

        The guard is module-wide: only the first listener to attach receives
        events, and later calls on any instance are no-ops.
        """
        global _ATTACHED
        if _ATTACHED:
            return
        owner = self

        @event.listens_for(AuditMixin, "before_insert", propagate=True)
        def _before_insert(mapper: Mapper[Any], connection: object, target: object) -> None:
            if not is_audit_model(target):
                return  # pragma: no cover
            owner.stamp_on_insert(target)

        @event.listens_for(AuditMixin, "before_update", propagate=True)
        def _before_update(mapper: Mapper[Any], connection: object, target: object) -> None:
            if not is_audit_model(target):
                return  # pragma: no cover
            owner.stamp_on_update(target)

        _ATTACHED = True
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
# Shared listener returned by get_audit_listener() and attached by configure_audit_listeners().
_default_listener: AuditListener = AuditListener()
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def get_audit_listener() -> AuditListener:
    """Return the module's shared ``AuditListener`` instance."""
    shared = _default_listener
    return shared
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def configure_audit_listeners() -> None:
    """Attach the shared listener's mapper events; repeated calls are no-ops."""
    shared = _default_listener
    shared.attach()
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def soft_delete(target: AuditMixin, *, actor: int | str | None = None) -> None:
    """Mark ``target`` soft-deleted via the configured audit listener.

    Only instance attributes are set; persisting them is left to the
    caller's next flush.
    """
    listener = get_audit_listener()
    listener.stamp_on_soft_delete(target, actor=actor)
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
"""Abstract mixins that gate audit listener processing."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from sqlalchemy import BigInteger
|
|
6
|
+
from sqlalchemy import DateTime
|
|
7
|
+
from sqlalchemy.orm import Mapped
|
|
8
|
+
from sqlalchemy.orm import mapped_column
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class AuditMixin:
    """Marker base: the audit listeners only stamp instances of its subclasses."""

    # Declares the mixin abstract, following the SQLAlchemy declarative convention.
    __abstract__ = True
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class CreatedTimestampModel(AuditMixin):
    """Audited mixin with a ``created_on`` column only (e.g. outbox, audit events, invites)."""

    __abstract__ = True

    # Timezone-aware; filled on insert by the audit listener when left empty.
    created_on: Mapped[datetime] = mapped_column(DateTime(timezone=True))
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class UpdatedTimestampModel(AuditMixin):
    """Audited mixin with an ``updated_on`` column only."""

    __abstract__ = True

    # Timezone-aware; set on insert (if empty) and refreshed on every update.
    updated_on: Mapped[datetime] = mapped_column(DateTime(timezone=True))
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class TimestampModel(AuditMixin):
    """Audited mixin with created/updated timestamps plus the acting user ids."""

    __abstract__ = True

    # Timestamps are timezone-aware and stamped by the audit listener.
    created_on: Mapped[datetime] = mapped_column(DateTime(timezone=True))
    updated_on: Mapped[datetime] = mapped_column(DateTime(timezone=True))
    # NOTE(review): AuditContext.actor_id may be a str, but these columns are
    # BigInteger — confirm that actors written here are always ints.
    created_by_id: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
    updated_by_id: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class SoftDeleteTimestampModel(TimestampModel):
    """``TimestampModel`` plus nullable soft-delete timestamp and actor columns."""

    __abstract__ = True

    # Both stay NULL until soft_delete() stamps them.
    deleted_on: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    deleted_by_id: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class SoftDeleteModel(AuditMixin):
    """Audited mixin with soft-delete columns but no created/updated stamps (e.g. Profile)."""

    __abstract__ = True

    # Both stay NULL until soft_delete() stamps them.
    deleted_on: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    deleted_by_id: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
sqlphilosophy/py.typed
ADDED
|
File without changes
|