fastapi-toolsets 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastapi_toolsets/__init__.py +24 -0
- fastapi_toolsets/cli/__init__.py +5 -0
- fastapi_toolsets/cli/app.py +97 -0
- fastapi_toolsets/cli/commands/__init__.py +1 -0
- fastapi_toolsets/cli/commands/fixtures.py +225 -0
- fastapi_toolsets/crud.py +378 -0
- fastapi_toolsets/db.py +175 -0
- fastapi_toolsets/exceptions/__init__.py +19 -0
- fastapi_toolsets/exceptions/exceptions.py +166 -0
- fastapi_toolsets/exceptions/handler.py +169 -0
- fastapi_toolsets/fixtures/__init__.py +17 -0
- fastapi_toolsets/fixtures/fixtures.py +321 -0
- fastapi_toolsets/fixtures/pytest_plugin.py +204 -0
- fastapi_toolsets/py.typed +0 -0
- fastapi_toolsets/schemas.py +116 -0
- fastapi_toolsets-0.1.0.dist-info/METADATA +89 -0
- fastapi_toolsets-0.1.0.dist-info/RECORD +20 -0
- fastapi_toolsets-0.1.0.dist-info/WHEEL +4 -0
- fastapi_toolsets-0.1.0.dist-info/entry_points.txt +3 -0
- fastapi_toolsets-0.1.0.dist-info/licenses/LICENSE +21 -0
fastapi_toolsets/crud.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
1
|
+
"""Generic async CRUD operations for SQLAlchemy models."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import Sequence
|
|
4
|
+
from typing import Any, ClassVar, Generic, Self, TypeVar, cast
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel
|
|
7
|
+
from sqlalchemy import and_, func, select
|
|
8
|
+
from sqlalchemy import delete as sql_delete
|
|
9
|
+
from sqlalchemy.dialects.postgresql import insert
|
|
10
|
+
from sqlalchemy.exc import NoResultFound
|
|
11
|
+
from sqlalchemy.ext.asyncio import AsyncSession
|
|
12
|
+
from sqlalchemy.orm import DeclarativeBase
|
|
13
|
+
from sqlalchemy.sql.roles import WhereHavingRole
|
|
14
|
+
|
|
15
|
+
from .db import get_transaction
|
|
16
|
+
from .exceptions import NotFoundError
|
|
17
|
+
|
|
18
|
+
# Public API of this module.
__all__ = ["AsyncCrud", "CrudFactory"]

# Type variable for the SQLAlchemy model a CRUD class is bound to.
ModelType = TypeVar("ModelType", bound=DeclarativeBase)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class AsyncCrud(Generic[ModelType]):
    """Generic async CRUD operations for SQLAlchemy models.

    Subclass this and set the `model` class variable, or use `CrudFactory`.

    Example:
        class UserCrud(AsyncCrud[User]):
            model = User

        # Or use the factory:
        UserCrud = CrudFactory(User)

        # Then use it:
        user = await UserCrud.get(session, [User.id == 1])
        users = await UserCrud.get_multi(session, limit=10)
    """

    # The SQLAlchemy model class every query of this CRUD class targets.
    model: ClassVar[type[DeclarativeBase]]

    @classmethod
    async def create(
        cls: type[Self],
        session: AsyncSession,
        obj: BaseModel,
    ) -> ModelType:
        """Create a new record in the database.

        Args:
            session: DB async session
            obj: Pydantic model with data to create

        Returns:
            Created model instance, reloaded from the database so that
            database-generated column values are populated.
        """
        async with get_transaction(session):
            db_model = cls.model(**obj.model_dump())
            session.add(db_model)
            # The INSERT must be emitted before refresh(): refreshing a
            # pending (not yet flushed) instance raises InvalidRequestError.
            await session.flush()
            await session.refresh(db_model)
            return cast(ModelType, db_model)

    @classmethod
    async def get(
        cls: type[Self],
        session: AsyncSession,
        filters: list[Any],
        *,
        with_for_update: bool = False,
        load_options: list[Any] | None = None,
    ) -> ModelType:
        """Get exactly one record. Raises NotFoundError if not found.

        Args:
            session: DB async session
            filters: List of SQLAlchemy filter conditions (ANDed together)
            with_for_update: Lock the row for update
            load_options: SQLAlchemy loader options (e.g., selectinload)

        Returns:
            Model instance

        Raises:
            NotFoundError: If no record found
            MultipleResultsFound: If more than one record found
        """
        q = select(cls.model)
        if filters:
            q = q.where(and_(*filters))
        if load_options:
            q = q.options(*load_options)
        if with_for_update:
            q = q.with_for_update()
        result = await session.execute(q)
        # unique() is required when joined eager loading is used.
        item = result.unique().scalar_one_or_none()
        # Explicit None check: a model may define __bool__/__len__.
        if item is None:
            raise NotFoundError()
        return cast(ModelType, item)

    @classmethod
    async def first(
        cls: type[Self],
        session: AsyncSession,
        filters: list[Any] | None = None,
        *,
        load_options: list[Any] | None = None,
    ) -> ModelType | None:
        """Get the first matching record, or None.

        Args:
            session: DB async session
            filters: List of SQLAlchemy filter conditions (ANDed together)
            load_options: SQLAlchemy loader options

        Returns:
            Model instance or None
        """
        q = select(cls.model)
        if filters:
            q = q.where(and_(*filters))
        if load_options:
            q = q.options(*load_options)
        result = await session.execute(q)
        return cast(ModelType | None, result.unique().scalars().first())

    @classmethod
    async def get_multi(
        cls: type[Self],
        session: AsyncSession,
        *,
        filters: list[Any] | None = None,
        load_options: list[Any] | None = None,
        order_by: Any | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> Sequence[ModelType]:
        """Get multiple records from the database.

        Args:
            session: DB async session
            filters: List of SQLAlchemy filter conditions (ANDed together)
            load_options: SQLAlchemy loader options
            order_by: Column, or list/tuple of columns, to order by
            limit: Max number of rows to return
            offset: Rows to skip

        Returns:
            List of model instances
        """
        q = select(cls.model)
        if filters:
            q = q.where(and_(*filters))
        if load_options:
            q = q.options(*load_options)
        if order_by is not None:
            # Select.order_by() takes columns as varargs, so unpack a
            # list/tuple instead of passing it as a single expression.
            if isinstance(order_by, (list, tuple)):
                q = q.order_by(*order_by)
            else:
                q = q.order_by(order_by)
        if offset is not None:
            q = q.offset(offset)
        if limit is not None:
            q = q.limit(limit)
        result = await session.execute(q)
        return cast(Sequence[ModelType], result.unique().scalars().all())

    @classmethod
    async def update(
        cls: type[Self],
        session: AsyncSession,
        obj: BaseModel,
        filters: list[Any],
        *,
        exclude_unset: bool = True,
        exclude_none: bool = False,
    ) -> ModelType:
        """Update a record in the database.

        Args:
            session: DB async session
            obj: Pydantic model with update data
            filters: List of SQLAlchemy filter conditions
            exclude_unset: Exclude fields not explicitly set in the schema
            exclude_none: Exclude fields with None value

        Returns:
            Updated model instance

        Raises:
            NotFoundError: If no record found
        """
        async with get_transaction(session):
            db_model = await cls.get(session=session, filters=filters)
            values = obj.model_dump(
                exclude_unset=exclude_unset, exclude_none=exclude_none
            )
            for key, value in values.items():
                setattr(db_model, key, value)
            # Flush first: refresh() expires the instance and would
            # otherwise silently discard these un-flushed attribute changes.
            await session.flush()
            await session.refresh(db_model)
            return db_model

    @classmethod
    async def upsert(
        cls: type[Self],
        session: AsyncSession,
        obj: BaseModel,
        index_elements: list[str],
        *,
        set_: BaseModel | None = None,
        where: WhereHavingRole | None = None,
    ) -> ModelType | None:
        """Create or update a record (PostgreSQL only).

        Uses INSERT ... ON CONFLICT for atomic upsert.

        Args:
            session: DB async session
            obj: Pydantic model with data (only explicitly set fields are used)
            index_elements: Columns for ON CONFLICT (unique constraint)
            set_: Pydantic model for ON CONFLICT DO UPDATE SET; when omitted,
                ON CONFLICT DO NOTHING is used
            where: WHERE clause for ON CONFLICT DO UPDATE

        Returns:
            The inserted/updated model instance. When RETURNING yields no
            row (DO NOTHING, or DO UPDATE rejected by `where`), falls back
            to the first row matching every inserted value, which is None
            if the existing row differs from `obj` in any of those columns.
        """
        async with get_transaction(session):
            values = obj.model_dump(exclude_unset=True)
            q = insert(cls.model).values(**values)
            if set_ is not None:
                q = q.on_conflict_do_update(
                    index_elements=index_elements,
                    set_=set_.model_dump(exclude_unset=True),
                    where=where,
                )
            else:
                q = q.on_conflict_do_nothing(index_elements=index_elements)
            q = q.returning(cls.model)
            result = await session.execute(q)
            try:
                db_model = result.unique().scalar_one()
            except NoResultFound:
                db_model = await cls.first(
                    session=session,
                    filters=[getattr(cls.model, k) == v for k, v in values.items()],
                )
            return cast(ModelType | None, db_model)

    @classmethod
    async def delete(
        cls: type[Self],
        session: AsyncSession,
        filters: list[Any],
    ) -> bool:
        """Delete records from the database.

        Args:
            session: DB async session
            filters: List of SQLAlchemy filter conditions. An empty list
                emits a DELETE without WHERE (all rows), as before.

        Returns:
            True if deletion was executed
        """
        async with get_transaction(session):
            q = sql_delete(cls.model)
            if filters:
                q = q.where(and_(*filters))
            await session.execute(q)
            return True

    @classmethod
    async def count(
        cls: type[Self],
        session: AsyncSession,
        filters: list[Any] | None = None,
    ) -> int:
        """Count records matching the filters.

        Args:
            session: DB async session
            filters: List of SQLAlchemy filter conditions

        Returns:
            Number of matching records
        """
        q = select(func.count()).select_from(cls.model)
        if filters:
            q = q.where(and_(*filters))
        result = await session.execute(q)
        return result.scalar_one()

    @classmethod
    async def exists(
        cls: type[Self],
        session: AsyncSession,
        filters: list[Any],
    ) -> bool:
        """Check if a record exists.

        Args:
            session: DB async session
            filters: List of SQLAlchemy filter conditions

        Returns:
            True if at least one record matches
        """
        inner = select(cls.model)
        if filters:
            inner = inner.where(and_(*filters))
        # SELECT EXISTS (SELECT ... WHERE ...)
        q = inner.exists().select()
        result = await session.execute(q)
        return bool(result.scalar())

    @classmethod
    async def paginate(
        cls: type[Self],
        session: AsyncSession,
        *,
        filters: list[Any] | None = None,
        load_options: list[Any] | None = None,
        order_by: Any | None = None,
        page: int = 1,
        items_per_page: int = 20,
    ) -> dict[str, Any]:
        """Get paginated results with metadata.

        Args:
            session: DB async session
            filters: List of SQLAlchemy filter conditions
            load_options: SQLAlchemy loader options
            order_by: Column, or list/tuple of columns, to order by
            page: Page number (1-indexed). NOTE(review): values below 1
                produce a negative OFFSET, which the database will
                presumably reject.
            items_per_page: Number of items per page

        Returns:
            Dict with 'data' and 'pagination' keys
        """
        filters = filters or []
        offset = (page - 1) * items_per_page

        items = await cls.get_multi(
            session,
            filters=filters,
            load_options=load_options,
            order_by=order_by,
            limit=items_per_page,
            offset=offset,
        )

        total_count = await cls.count(session, filters=filters)

        return {
            "data": items,
            "pagination": {
                "total_count": total_count,
                "items_per_page": items_per_page,
                "page": page,
                "has_more": page * items_per_page < total_count,
            },
        }
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
def CrudFactory(
    model: type[ModelType],
) -> type[AsyncCrud[ModelType]]:
    """Build an `AsyncCrud` subclass whose `model` is the given class.

    The generated class is named ``Async<ModelName>Crud`` and only sets the
    `model` class attribute; all behavior is inherited from `AsyncCrud`.

    Args:
        model: SQLAlchemy model class

    Returns:
        AsyncCrud subclass bound to the model

    Example:
        from fastapi_toolsets.crud import CrudFactory
        from myapp.models import User, Post

        UserCrud = CrudFactory(User)
        PostCrud = CrudFactory(Post)

        # Usage
        user = await UserCrud.get(session, [User.id == 1])
        posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
    """
    namespace = {"model": model}
    generated = type(f"Async{model.__name__}Crud", (AsyncCrud,), namespace)
    return cast(type[AsyncCrud[ModelType]], generated)
|
fastapi_toolsets/db.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
"""Database utilities: sessions, transactions, and locks."""
|
|
2
|
+
|
|
3
|
+
from collections.abc import AsyncGenerator, Callable
|
|
4
|
+
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
|
5
|
+
from enum import Enum
|
|
6
|
+
|
|
7
|
+
from sqlalchemy import text
|
|
8
|
+
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|
9
|
+
from sqlalchemy.orm import DeclarativeBase
|
|
10
|
+
|
|
11
|
+
# Public API of this module.
__all__ = [
    "LockMode", "create_db_context", "create_db_dependency",
    "lock_tables", "get_transaction",
]
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def create_db_dependency(
    session_maker: async_sessionmaker[AsyncSession],
) -> Callable[[], AsyncGenerator[AsyncSession, None]]:
    """Build a FastAPI dependency that provides a database session.

    The returned async generator opens a session from `session_maker`,
    yields it to the request handler and, once the handler finishes without
    raising, commits if the session still has an active transaction. The
    session is closed when the generator exits.

    Args:
        session_maker: An `async_sessionmaker` bound to your engine

    Returns:
        An async generator function usable with FastAPI's Depends()

    Example:
        from fastapi import Depends
        from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
        from fastapi_toolsets.db import create_db_dependency

        engine = create_async_engine("postgresql+asyncpg://...")
        SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
        get_db = create_db_dependency(SessionLocal)

        @app.get("/users")
        async def list_users(session: AsyncSession = Depends(get_db)):
            ...
    """

    async def get_db() -> AsyncGenerator[AsyncSession, None]:
        async with session_maker() as db_session:
            yield db_session
            # Nothing left to persist if no transaction is open.
            if not db_session.in_transaction():
                return
            await db_session.commit()

    return get_db
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def create_db_context(
    session_maker: async_sessionmaker[AsyncSession],
) -> Callable[[], AbstractAsyncContextManager[AsyncSession]]:
    """Build an async context manager that provides a database session.

    Intended for code running outside FastAPI request handling, such as
    background tasks, CLI commands or tests. It wraps the same generator
    produced by `create_db_dependency`, so an active transaction is
    committed when the block exits without an exception.

    Args:
        session_maker: An `async_sessionmaker` bound to your engine

    Returns:
        An async context manager function

    Example:
        from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
        from fastapi_toolsets.db import create_db_context

        engine = create_async_engine("postgresql+asyncpg://...")
        SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
        get_db_context = create_db_context(SessionLocal)

        async def background_task():
            async with get_db_context() as session:
                user = await UserCrud.get(session, [User.id == 1])
                ...
    """
    return asynccontextmanager(create_db_dependency(session_maker))
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@asynccontextmanager
async def get_transaction(
    session: AsyncSession,
) -> AsyncGenerator[AsyncSession, None]:
    """Open a transaction scope on `session`, nesting if one is active.

    When the session is already inside a transaction, a SAVEPOINT is opened
    via `begin_nested()`; otherwise a top-level transaction is started via
    `begin()`. Either scope commits on normal exit and rolls back if the
    body raises.

    Args:
        session: AsyncSession instance

    Yields:
        The session within the transaction context

    Example:
        async with get_transaction(session):
            session.add(model)
        # Auto-commits on exit, rolls back on exception
    """
    nested = session.in_transaction()
    scope = session.begin_nested() if nested else session.begin()
    async with scope:
        yield session
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class LockMode(str, Enum):
    """PostgreSQL table lock modes.

    Each value is the keyword sequence placed between ``IN`` and ``MODE``
    of a ``LOCK ... IN <mode> MODE`` statement (see `lock_tables`).
    Members are listed in the same order as the PostgreSQL documentation.

    See: https://www.postgresql.org/docs/current/explicit-locking.html
    """

    ACCESS_SHARE = "ACCESS SHARE"
    ROW_SHARE = "ROW SHARE"
    ROW_EXCLUSIVE = "ROW EXCLUSIVE"
    # Default mode used by lock_tables().
    SHARE_UPDATE_EXCLUSIVE = "SHARE UPDATE EXCLUSIVE"
    SHARE = "SHARE"
    SHARE_ROW_EXCLUSIVE = "SHARE ROW EXCLUSIVE"
    EXCLUSIVE = "EXCLUSIVE"
    ACCESS_EXCLUSIVE = "ACCESS EXCLUSIVE"
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
@asynccontextmanager
async def lock_tables(
    session: AsyncSession,
    tables: list[type[DeclarativeBase]],
    *,
    mode: LockMode | str = LockMode.SHARE_UPDATE_EXCLUSIVE,
    timeout: str = "5s",
) -> AsyncGenerator[AsyncSession, None]:
    """Lock PostgreSQL tables for the duration of a transaction.

    Acquires table-level locks that are held until the transaction ends.
    Useful for preventing concurrent modifications during critical operations.

    Args:
        session: AsyncSession instance
        tables: Non-empty list of SQLAlchemy model classes to lock
        mode: Lock mode (default: SHARE UPDATE EXCLUSIVE). A plain string is
            accepted if it equals one of the `LockMode` values.
        timeout: Lock timeout, in any format PostgreSQL's ``lock_timeout``
            setting accepts (default: "5s")

    Yields:
        The session with locked tables

    Raises:
        ValueError: If `tables` is empty or `mode` is not a valid lock mode
        SQLAlchemyError: If lock cannot be acquired within timeout

    Example:
        from fastapi_toolsets.db import lock_tables, LockMode

        async with lock_tables(session, [User, Account]):
            # Tables are locked with SHARE UPDATE EXCLUSIVE mode
            user = await UserCrud.get(session, [User.id == 1])
            user.balance += 100

        # With custom lock mode
        async with lock_tables(session, [Order], mode=LockMode.EXCLUSIVE):
            # Exclusive lock - no other transactions can access
            await process_order(session, order_id)
    """
    from sqlalchemy.dialects import postgresql

    if not tables:
        raise ValueError("lock_tables() requires at least one table")

    # Coercing through the enum guarantees that only one of the known
    # keyword sequences is interpolated into the LOCK statement.
    lock_mode = LockMode(mode)

    # Quote identifiers (and include the schema, if any) exactly as
    # SQLAlchemy's PostgreSQL dialect does when emitting DDL.
    preparer = postgresql.dialect().identifier_preparer
    table_names = ", ".join(
        preparer.format_table(table.__table__) for table in tables
    )

    async with get_transaction(session):
        # set_config(..., is_local => true) is the parameterizable
        # equivalent of SET LOCAL, so `timeout` is never spliced into SQL.
        await session.execute(
            text("SELECT set_config('lock_timeout', :timeout, true)"),
            {"timeout": timeout},
        )
        await session.execute(text(f"LOCK {table_names} IN {lock_mode.value} MODE"))
        yield session
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
from .exceptions import (
|
|
2
|
+
ApiException,
|
|
3
|
+
ConflictError,
|
|
4
|
+
ForbiddenError,
|
|
5
|
+
NotFoundError,
|
|
6
|
+
UnauthorizedError,
|
|
7
|
+
generate_error_responses,
|
|
8
|
+
)
|
|
9
|
+
from .handler import init_exceptions_handlers
|
|
10
|
+
|
|
11
|
+
# Public API of the exceptions package.
__all__ = [
    "init_exceptions_handlers", "generate_error_responses",
    "ApiException", "ConflictError", "ForbiddenError",
    "NotFoundError", "UnauthorizedError",
]
|