fastapi-toolsets 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,378 @@
1
+ """Generic async CRUD operations for SQLAlchemy models."""
2
+
3
+ from collections.abc import Sequence
4
+ from typing import Any, ClassVar, Generic, Self, TypeVar, cast
5
+
6
+ from pydantic import BaseModel
7
+ from sqlalchemy import and_, func, select
8
+ from sqlalchemy import delete as sql_delete
9
+ from sqlalchemy.dialects.postgresql import insert
10
+ from sqlalchemy.exc import NoResultFound
11
+ from sqlalchemy.ext.asyncio import AsyncSession
12
+ from sqlalchemy.orm import DeclarativeBase
13
+ from sqlalchemy.sql.roles import WhereHavingRole
14
+
15
+ from .db import get_transaction
16
+ from .exceptions import NotFoundError
17
+
18
+ __all__ = [
19
+ "AsyncCrud",
20
+ "CrudFactory",
21
+ ]
22
+
23
+ ModelType = TypeVar("ModelType", bound=DeclarativeBase)
24
+
25
+
26
class AsyncCrud(Generic[ModelType]):
    """Generic async CRUD operations for SQLAlchemy models.

    Subclass this and set the `model` class variable, or use `CrudFactory`.
    Every operation is a classmethod that takes the session as its first
    argument, so the class itself is never instantiated.

    Example:
        class UserCrud(AsyncCrud[User]):
            model = User

        # Or use the factory:
        UserCrud = CrudFactory(User)

        # Then use it:
        user = await UserCrud.get(session, [User.id == 1])
        users = await UserCrud.get_multi(session, limit=10)
    """

    # Mapped class every query is built against; must be set by subclasses.
    model: ClassVar[type[DeclarativeBase]]

    @staticmethod
    def _apply_filters(query: Any, filters: list[Any] | None) -> Any:
        """Return `query` narrowed by the AND of `filters` (no-op when empty)."""
        # Calling and_() with zero arguments is deprecated in SQLAlchemy, so
        # the WHERE clause is only attached when there is something to filter.
        if not filters:
            return query
        return query.where(and_(*filters))

    @classmethod
    async def create(
        cls: type[Self],
        session: AsyncSession,
        obj: BaseModel,
    ) -> ModelType:
        """Create a new record in the database.

        Args:
            session: DB async session
            obj: Pydantic model with data to create

        Returns:
            Created model instance
        """
        async with get_transaction(session):
            db_model = cls.model(**obj.model_dump())
            session.add(db_model)
        # Refresh only after the transaction block has flushed the row: a
        # pending instance cannot be refreshed, and a commit may expire the
        # attributes again when the session uses expire_on_commit.
        await session.refresh(db_model)
        return cast(ModelType, db_model)

    @classmethod
    async def get(
        cls: type[Self],
        session: AsyncSession,
        filters: list[Any],
        *,
        with_for_update: bool = False,
        load_options: list[Any] | None = None,
    ) -> ModelType:
        """Get exactly one record. Raises NotFoundError if not found.

        Args:
            session: DB async session
            filters: List of SQLAlchemy filter conditions
            with_for_update: Lock the row for update
            load_options: SQLAlchemy loader options (e.g., selectinload)

        Returns:
            Model instance

        Raises:
            NotFoundError: If no record found
            MultipleResultsFound: If more than one record found
        """
        q = cls._apply_filters(select(cls.model), filters)
        if load_options:
            q = q.options(*load_options)
        if with_for_update:
            q = q.with_for_update()
        result = await session.execute(q)
        item = result.unique().scalar_one_or_none()
        # Identity check: a mapped class is free to define __bool__/__len__,
        # and a falsy instance must not be reported as "not found".
        if item is None:
            raise NotFoundError()
        return cast(ModelType, item)

    @classmethod
    async def first(
        cls: type[Self],
        session: AsyncSession,
        filters: list[Any] | None = None,
        *,
        load_options: list[Any] | None = None,
    ) -> ModelType | None:
        """Get the first matching record, or None.

        No ordering is applied, so which row counts as "first" is decided by
        the database when several rows match.

        Args:
            session: DB async session
            filters: List of SQLAlchemy filter conditions
            load_options: SQLAlchemy loader options

        Returns:
            Model instance or None
        """
        q = cls._apply_filters(select(cls.model), filters)
        if load_options:
            q = q.options(*load_options)
        # Result.first() does not add a LIMIT by itself; ask the database for
        # a single entity instead of fetching every match and discarding it.
        q = q.limit(1)
        result = await session.execute(q)
        return cast(ModelType | None, result.unique().scalars().first())

    @classmethod
    async def get_multi(
        cls: type[Self],
        session: AsyncSession,
        *,
        filters: list[Any] | None = None,
        load_options: list[Any] | None = None,
        order_by: Any | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> Sequence[ModelType]:
        """Get multiple records from the database.

        Args:
            session: DB async session
            filters: List of SQLAlchemy filter conditions
            load_options: SQLAlchemy loader options
            order_by: Column or list/tuple of columns to order by
            limit: Max number of rows to return
            offset: Rows to skip

        Returns:
            List of model instances
        """
        q = cls._apply_filters(select(cls.model), filters)
        if load_options:
            q = q.options(*load_options)
        if order_by is not None:
            # order_by() takes its criteria positionally, so a list/tuple has
            # to be unpacked rather than passed through as one argument.
            if isinstance(order_by, (list, tuple)):
                q = q.order_by(*order_by)
            else:
                q = q.order_by(order_by)
        if offset is not None:
            q = q.offset(offset)
        if limit is not None:
            q = q.limit(limit)
        result = await session.execute(q)
        return cast(Sequence[ModelType], result.unique().scalars().all())

    @classmethod
    async def update(
        cls: type[Self],
        session: AsyncSession,
        obj: BaseModel,
        filters: list[Any],
        *,
        exclude_unset: bool = True,
        exclude_none: bool = False,
    ) -> ModelType:
        """Update a record in the database.

        Args:
            session: DB async session
            obj: Pydantic model with update data
            filters: List of SQLAlchemy filter conditions
            exclude_unset: Exclude fields not explicitly set in the schema
            exclude_none: Exclude fields with None value

        Returns:
            Updated model instance

        Raises:
            NotFoundError: If no record found
            MultipleResultsFound: If the filters match more than one record
        """
        values = obj.model_dump(
            exclude_unset=exclude_unset, exclude_none=exclude_none
        )
        async with get_transaction(session):
            db_model = await cls.get(session=session, filters=filters)
            for key, value in values.items():
                setattr(db_model, key, value)
        # Reload after the block has flushed so the returned instance carries
        # the stored state (see create() for why this is not done inside).
        await session.refresh(db_model)
        return db_model

    @classmethod
    async def upsert(
        cls: type[Self],
        session: AsyncSession,
        obj: BaseModel,
        index_elements: list[str],
        *,
        set_: BaseModel | None = None,
        where: WhereHavingRole | None = None,
    ) -> ModelType | None:
        """Create or update a record (PostgreSQL only).

        Uses INSERT ... ON CONFLICT for atomic upsert. Without `set_` the
        statement is ON CONFLICT DO NOTHING; with it, ON CONFLICT DO UPDATE.

        Args:
            session: DB async session
            obj: Pydantic model with data
            index_elements: Columns for ON CONFLICT (unique constraint)
            set_: Pydantic model for ON CONFLICT DO UPDATE SET
            where: WHERE clause for ON CONFLICT DO UPDATE

        Returns:
            The inserted or updated instance. When the statement touched no
            row (conflict with DO NOTHING, or the DO UPDATE `where` did not
            match) the already-stored conflicting row is looked up and
            returned instead; None if that lookup finds nothing.
        """
        values = obj.model_dump(exclude_unset=True)
        async with get_transaction(session):
            q = insert(cls.model).values(**values)
            if set_ is not None:
                q = q.on_conflict_do_update(
                    index_elements=index_elements,
                    set_=set_.model_dump(exclude_unset=True),
                    where=where,
                )
            else:
                q = q.on_conflict_do_nothing(index_elements=index_elements)
            q = q.returning(cls.model)
            # populate_existing makes RETURNING overwrite an instance that is
            # already in the identity map, so DO UPDATE never hands back
            # stale attribute values.
            result = await session.execute(
                q, execution_options={"populate_existing": True}
            )
            try:
                db_model = result.unique().scalar_one()
            except NoResultFound:
                # Nothing was written. Find the existing row by the conflict
                # target: matching on *all* inserted values would miss a row
                # that differs in any non-key column. Only when the target
                # columns were not all supplied do we fall back to matching
                # on every supplied value.
                if all(name in values for name in index_elements):
                    lookup_keys = list(index_elements)
                else:
                    lookup_keys = list(values)
                db_model = await cls.first(
                    session=session,
                    filters=[
                        getattr(cls.model, key) == values[key]
                        for key in lookup_keys
                    ],
                )
        return cast(ModelType | None, db_model)

    @classmethod
    async def delete(
        cls: type[Self],
        session: AsyncSession,
        filters: list[Any],
    ) -> bool:
        """Delete records from the database.

        Warning:
            An empty `filters` list adds no WHERE clause, so every row of the
            table is deleted.

        Args:
            session: DB async session
            filters: List of SQLAlchemy filter conditions

        Returns:
            True if deletion was executed (regardless of how many rows
            matched)
        """
        async with get_transaction(session):
            q = cls._apply_filters(sql_delete(cls.model), filters)
            await session.execute(q)
        return True

    @classmethod
    async def count(
        cls: type[Self],
        session: AsyncSession,
        filters: list[Any] | None = None,
    ) -> int:
        """Count records matching the filters.

        Args:
            session: DB async session
            filters: List of SQLAlchemy filter conditions

        Returns:
            Number of matching records
        """
        q = cls._apply_filters(
            select(func.count()).select_from(cls.model), filters
        )
        result = await session.execute(q)
        return result.scalar_one()

    @classmethod
    async def exists(
        cls: type[Self],
        session: AsyncSession,
        filters: list[Any],
    ) -> bool:
        """Check if a record exists.

        Args:
            session: DB async session
            filters: List of SQLAlchemy filter conditions

        Returns:
            True if at least one record matches
        """
        # Builds SELECT EXISTS (SELECT ... WHERE ...) so the database can
        # stop at the first matching row.
        q = cls._apply_filters(select(cls.model), filters).exists().select()
        result = await session.execute(q)
        return bool(result.scalar())

    @classmethod
    async def paginate(
        cls: type[Self],
        session: AsyncSession,
        *,
        filters: list[Any] | None = None,
        load_options: list[Any] | None = None,
        order_by: Any | None = None,
        page: int = 1,
        items_per_page: int = 20,
    ) -> dict[str, Any]:
        """Get paginated results with metadata.

        Pass `order_by` for stable pages: without an explicit ordering the
        database may return rows in a different order on each request.

        Args:
            session: DB async session
            filters: List of SQLAlchemy filter conditions
            load_options: SQLAlchemy loader options
            order_by: Column or list/tuple of columns to order by
            page: Page number (1-indexed)
            items_per_page: Number of items per page

        Returns:
            Dict with 'data' and 'pagination' keys

        Raises:
            ValueError: If `page` or `items_per_page` is smaller than 1
        """
        # Reject values that would otherwise become a negative OFFSET/LIMIT
        # or (items_per_page == 0) a page that never stops having "more".
        if page < 1:
            raise ValueError("page must be greater than or equal to 1")
        if items_per_page < 1:
            raise ValueError("items_per_page must be greater than or equal to 1")

        filters = filters or []
        offset = (page - 1) * items_per_page

        items = await cls.get_multi(
            session,
            filters=filters,
            load_options=load_options,
            order_by=order_by,
            limit=items_per_page,
            offset=offset,
        )

        total_count = await cls.count(session, filters=filters)

        return {
            "data": items,
            "pagination": {
                "total_count": total_count,
                "items_per_page": items_per_page,
                "page": page,
                "has_more": page * items_per_page < total_count,
            },
        }
353
+
354
+
355
def CrudFactory(
    model: type[ModelType],
) -> type[AsyncCrud[ModelType]]:
    """Build an `AsyncCrud` subclass whose `model` attribute is `model`.

    The generated class is named ``Async<ModelName>Crud`` and carries no
    members beyond the `model` class attribute.

    Args:
        model: SQLAlchemy model class

    Returns:
        AsyncCrud subclass bound to the model

    Example:
        from fastapi_toolsets.crud import CrudFactory
        from myapp.models import User, Post

        UserCrud = CrudFactory(User)
        PostCrud = CrudFactory(Post)

        # Usage
        user = await UserCrud.get(session, [User.id == 1])
        posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
    """
    class_name = f"Async{model.__name__}Crud"
    namespace = {"model": model}
    # Three-argument type() creates the subclass dynamically; the cast only
    # tells the type checker which model the result is parameterised with.
    crud_cls = type(class_name, (AsyncCrud,), namespace)
    return cast(type[AsyncCrud[ModelType]], crud_cls)
fastapi_toolsets/db.py ADDED
@@ -0,0 +1,175 @@
1
+ """Database utilities: sessions, transactions, and locks."""
2
+
3
+ from collections.abc import AsyncGenerator, Callable
4
+ from contextlib import AbstractAsyncContextManager, asynccontextmanager
5
+ from enum import Enum
6
+
7
+ from sqlalchemy import text
8
+ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
9
+ from sqlalchemy.orm import DeclarativeBase
10
+
11
+ __all__ = [
12
+ "LockMode",
13
+ "create_db_context",
14
+ "create_db_dependency",
15
+ "lock_tables",
16
+ "get_transaction",
17
+ ]
18
+
19
+
20
def create_db_dependency(
    session_maker: async_sessionmaker[AsyncSession],
) -> Callable[[], AsyncGenerator[AsyncSession, None]]:
    """Create a FastAPI dependency for database sessions.

    The returned async generator opens a session from `session_maker`,
    yields it, and, once the consumer resumes the generator normally,
    commits if the session still reports an active transaction. The session
    context manager is exited in every case.

    Args:
        session_maker: Async session factory

    Returns:
        An async generator function usable with FastAPI's Depends()

    Example:
        from fastapi import Depends
        from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
        from fastapi_toolsets.db import create_db_dependency

        engine = create_async_engine("postgresql+asyncpg://...")
        SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
        get_db = create_db_dependency(SessionLocal)

        @app.get("/users")
        async def list_users(session: AsyncSession = Depends(get_db)):
            ...
    """

    async def get_db() -> AsyncGenerator[AsyncSession, None]:
        session_scope = session_maker()
        async with session_scope as db_session:
            yield db_session
            # Reached only when no exception was thrown into the generator.
            if not db_session.in_transaction():
                return
            await db_session.commit()

    return get_db
55
+
56
+
57
def create_db_context(
    session_maker: async_sessionmaker[AsyncSession],
) -> Callable[[], AbstractAsyncContextManager[AsyncSession]]:
    """Create a context manager for database sessions.

    Wraps the generator produced by `create_db_dependency` with
    `asynccontextmanager`, so the same session handling can be used with
    ``async with`` outside of FastAPI request handlers (background tasks,
    CLI commands, tests).

    Args:
        session_maker: Async session factory

    Returns:
        An async context manager function

    Example:
        from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
        from fastapi_toolsets.db import create_db_context

        engine = create_async_engine("postgresql+asyncpg://...")
        SessionLocal = async_sessionmaker(engine, expire_on_commit=False)
        get_db_context = create_db_context(SessionLocal)

        async def background_task():
            async with get_db_context() as session:
                user = await UserCrud.get(session, [User.id == 1])
                ...
    """
    return asynccontextmanager(create_db_dependency(session_maker))
86
+
87
+
88
@asynccontextmanager
async def get_transaction(
    session: AsyncSession,
) -> AsyncGenerator[AsyncSession, None]:
    """Open a transaction scope on `session`, nesting when one is active.

    When the session already reports an active transaction the scope is
    ``session.begin_nested()`` (a savepoint); otherwise it is
    ``session.begin()``. Commit and rollback are left to that SQLAlchemy
    context manager when the ``async with`` block exits.

    Args:
        session: AsyncSession instance

    Yields:
        The session within the transaction context

    Example:
        async with get_transaction(session):
            session.add(model)
        # Auto-commits on exit, rolls back on exception
    """
    already_in_transaction = session.in_transaction()
    scope = session.begin_nested() if already_in_transaction else session.begin()
    async with scope:
        yield session
114
+
115
+
116
class LockMode(str, Enum):
    """PostgreSQL table lock modes.

    Each member's value is the mode keyword sequence exactly as it is written
    in a ``LOCK ... IN <mode> MODE`` statement. Members also subclass `str`,
    so they compare equal to those plain strings.

    See: https://www.postgresql.org/docs/current/explicit-locking.html
    """

    # Declaration order follows the order of the PostgreSQL documentation.
    ACCESS_SHARE = "ACCESS SHARE"
    ROW_SHARE = "ROW SHARE"
    ROW_EXCLUSIVE = "ROW EXCLUSIVE"
    SHARE_UPDATE_EXCLUSIVE = "SHARE UPDATE EXCLUSIVE"
    SHARE = "SHARE"
    SHARE_ROW_EXCLUSIVE = "SHARE ROW EXCLUSIVE"
    EXCLUSIVE = "EXCLUSIVE"
    ACCESS_EXCLUSIVE = "ACCESS EXCLUSIVE"
130
+
131
+
132
@asynccontextmanager
async def lock_tables(
    session: AsyncSession,
    tables: list[type[DeclarativeBase]],
    *,
    mode: LockMode = LockMode.SHARE_UPDATE_EXCLUSIVE,
    timeout: str = "5s",
) -> AsyncGenerator[AsyncSession, None]:
    """Lock PostgreSQL tables for the duration of a transaction.

    Opens a transaction scope via `get_transaction`, sets a transaction-local
    ``lock_timeout`` and issues a single ``LOCK TABLE`` statement for all
    given tables before yielding the session.

    Args:
        session: AsyncSession instance
        tables: List of SQLAlchemy model classes to lock (must not be empty)
        mode: Lock mode (default: SHARE UPDATE EXCLUSIVE)
        timeout: Lock timeout (default: "5s")

    Yields:
        The session with locked tables

    Raises:
        ValueError: If `tables` is empty or `mode` is not a valid LockMode
        SQLAlchemyError: If lock cannot be acquired within timeout

    Example:
        from fastapi_toolsets.db import lock_tables, LockMode

        async with lock_tables(session, [User, Account]):
            # Tables are locked with SHARE UPDATE EXCLUSIVE mode
            user = await UserCrud.get(session, [User.id == 1])
            user.balance += 100

        # With custom lock mode
        async with lock_tables(session, [Order], mode=LockMode.EXCLUSIVE):
            # Exclusive lock - no other transactions can access
            await process_order(session, order_id)
    """
    if not tables:
        # An empty list would otherwise render the invalid "LOCK  IN ... MODE".
        raise ValueError("lock_tables() requires at least one table")

    def _quote(identifier: str) -> str:
        # Standard SQL identifier quoting: wrap in double quotes and double
        # any embedded quote. Needed for reserved words ("order", "user") and
        # for mixed-case names.
        return '"' + identifier.replace('"', '""') + '"'

    def _qualified_name(table: type[DeclarativeBase]) -> str:
        # Prefix the schema when the mapped Table declares one.
        schema = getattr(getattr(table, "__table__", None), "schema", None)
        name = _quote(table.__tablename__)
        return f"{_quote(schema)}.{name}" if schema else name

    table_names = ", ".join(_qualified_name(table) for table in tables)
    # Round-trip through the enum so only the known mode keywords can ever be
    # interpolated into the statement (also accepts the plain string value).
    lock_mode = LockMode(mode).value

    async with get_transaction(session):
        # SET LOCAL cannot take bind parameters; set_config(..., true) is the
        # function form with the same transaction-local effect and lets the
        # timeout travel as a bound value instead of being spliced into SQL.
        await session.execute(
            text("SELECT set_config('lock_timeout', :timeout, true)"),
            {"timeout": str(timeout)},
        )
        await session.execute(
            text(f"LOCK TABLE {table_names} IN {lock_mode} MODE")
        )
        yield session
@@ -0,0 +1,19 @@
1
+ from .exceptions import (
2
+ ApiException,
3
+ ConflictError,
4
+ ForbiddenError,
5
+ NotFoundError,
6
+ UnauthorizedError,
7
+ generate_error_responses,
8
+ )
9
+ from .handler import init_exceptions_handlers
10
+
11
+ __all__ = [
12
+ "init_exceptions_handlers",
13
+ "generate_error_responses",
14
+ "ApiException",
15
+ "ConflictError",
16
+ "ForbiddenError",
17
+ "NotFoundError",
18
+ "UnauthorizedError",
19
+ ]