fastapi-toolsets 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -21,4 +21,4 @@ Example usage:
21
21
  return Response(data={"user": user.username}, message="Success")
22
22
  """
23
23
 
24
- __version__ = "0.2.0"
24
+ __version__ = "0.4.0"
@@ -0,0 +1,17 @@
1
+ """Generic async CRUD operations for SQLAlchemy models."""
2
+
3
+ from ..exceptions import NoSearchableFieldsError
4
+ from .factory import CrudFactory
5
+ from .search import (
6
+ SearchConfig,
7
+ SearchFieldType,
8
+ get_searchable_fields,
9
+ )
10
+
11
+ __all__ = [
12
+ "CrudFactory",
13
+ "get_searchable_fields",
14
+ "NoSearchableFieldsError",
15
+ "SearchConfig",
16
+ "SearchFieldType",
17
+ ]
@@ -12,13 +12,9 @@ from sqlalchemy.ext.asyncio import AsyncSession
12
12
  from sqlalchemy.orm import DeclarativeBase
13
13
  from sqlalchemy.sql.roles import WhereHavingRole
14
14
 
15
- from .db import get_transaction
16
- from .exceptions import NotFoundError
17
-
18
- __all__ = [
19
- "AsyncCrud",
20
- "CrudFactory",
21
- ]
15
+ from ..db import get_transaction
16
+ from ..exceptions import NotFoundError
17
+ from .search import SearchConfig, SearchFieldType, build_search_filters
22
18
 
23
19
  ModelType = TypeVar("ModelType", bound=DeclarativeBase)
24
20
 
@@ -27,20 +23,10 @@ class AsyncCrud(Generic[ModelType]):
27
23
  """Generic async CRUD operations for SQLAlchemy models.
28
24
 
29
25
  Subclass this and set the `model` class variable, or use `CrudFactory`.
30
-
31
- Example:
32
- class UserCrud(AsyncCrud[User]):
33
- model = User
34
-
35
- # Or use the factory:
36
- UserCrud = CrudFactory(User)
37
-
38
- # Then use it:
39
- user = await UserCrud.get(session, [User.id == 1])
40
- users = await UserCrud.get_multi(session, limit=10)
41
26
  """
42
27
 
43
28
  model: ClassVar[type[DeclarativeBase]]
29
+ searchable_fields: ClassVar[Sequence[SearchFieldType] | None] = None
44
30
 
45
31
  @classmethod
46
32
  async def create(
@@ -313,6 +299,8 @@ class AsyncCrud(Generic[ModelType]):
313
299
  order_by: Any | None = None,
314
300
  page: int = 1,
315
301
  items_per_page: int = 20,
302
+ search: str | SearchConfig | None = None,
303
+ search_fields: Sequence[SearchFieldType] | None = None,
316
304
  ) -> dict[str, Any]:
317
305
  """Get paginated results with metadata.
318
306
 
@@ -323,23 +311,54 @@ class AsyncCrud(Generic[ModelType]):
323
311
  order_by: Column or list of columns to order by
324
312
  page: Page number (1-indexed)
325
313
  items_per_page: Number of items per page
314
+ search: Search query string or SearchConfig object
315
+ search_fields: Fields to search in (overrides class default)
326
316
 
327
317
  Returns:
328
318
  Dict with 'data' and 'pagination' keys
329
319
  """
330
- filters = filters or []
320
+ filters = list(filters) if filters else []
331
321
  offset = (page - 1) * items_per_page
322
+ joins: list[Any] = []
323
+
324
+ # Build search filters
325
+ if search:
326
+ search_filters, search_joins = build_search_filters(
327
+ cls.model,
328
+ search,
329
+ search_fields=search_fields,
330
+ default_fields=cls.searchable_fields,
331
+ )
332
+ filters.extend(search_filters)
333
+ joins.extend(search_joins)
332
334
 
333
- items = await cls.get_multi(
334
- session,
335
- filters=filters,
336
- load_options=load_options,
337
- order_by=order_by,
338
- limit=items_per_page,
339
- offset=offset,
340
- )
335
+ # Build query with joins
336
+ q = select(cls.model)
337
+ for join_rel in joins:
338
+ q = q.outerjoin(join_rel)
341
339
 
342
- total_count = await cls.count(session, filters=filters)
340
+ if filters:
341
+ q = q.where(and_(*filters))
342
+ if load_options:
343
+ q = q.options(*load_options)
344
+ if order_by is not None:
345
+ q = q.order_by(order_by)
346
+
347
+ q = q.offset(offset).limit(items_per_page)
348
+ result = await session.execute(q)
349
+ items = result.unique().scalars().all()
350
+
351
+ # Count query (with same joins and filters)
352
+ pk_col = cls.model.__mapper__.primary_key[0]
353
+ count_q = select(func.count(func.distinct(getattr(cls.model, pk_col.name))))
354
+ count_q = count_q.select_from(cls.model)
355
+ for join_rel in joins:
356
+ count_q = count_q.outerjoin(join_rel)
357
+ if filters:
358
+ count_q = count_q.where(and_(*filters))
359
+
360
+ count_result = await session.execute(count_q)
361
+ total_count = count_result.scalar_one()
343
362
 
344
363
  return {
345
364
  "data": items,
@@ -354,11 +373,14 @@ class AsyncCrud(Generic[ModelType]):
354
373
 
355
374
  def CrudFactory(
356
375
  model: type[ModelType],
376
+ *,
377
+ searchable_fields: Sequence[SearchFieldType] | None = None,
357
378
  ) -> type[AsyncCrud[ModelType]]:
358
379
  """Create a CRUD class for a specific model.
359
380
 
360
381
  Args:
361
382
  model: SQLAlchemy model class
383
+ searchable_fields: Optional list of searchable fields
362
384
 
363
385
  Returns:
364
386
  AsyncCrud subclass bound to the model
@@ -370,9 +392,25 @@ def CrudFactory(
370
392
  UserCrud = CrudFactory(User)
371
393
  PostCrud = CrudFactory(Post)
372
394
 
395
+ # With searchable fields:
396
+ UserCrud = CrudFactory(
397
+ User,
398
+ searchable_fields=[User.username, User.email, (User.role, Role.name)]
399
+ )
400
+
373
401
  # Usage
374
402
  user = await UserCrud.get(session, [User.id == 1])
375
403
  posts = await PostCrud.get_multi(session, filters=[Post.user_id == user.id])
404
+
405
+ # With search
406
+ result = await UserCrud.paginate(session, search="john")
376
407
  """
377
- cls = type(f"Async{model.__name__}Crud", (AsyncCrud,), {"model": model})
408
+ cls = type(
409
+ f"Async{model.__name__}Crud",
410
+ (AsyncCrud,),
411
+ {
412
+ "model": model,
413
+ "searchable_fields": searchable_fields,
414
+ },
415
+ )
378
416
  return cast(type[AsyncCrud[ModelType]], cls)
@@ -0,0 +1,145 @@
1
+ """Search utilities for AsyncCrud."""
2
+
3
+ from collections.abc import Sequence
4
+ from dataclasses import dataclass
5
+ from typing import TYPE_CHECKING, Any, Literal
6
+
7
+ from sqlalchemy import String, or_
8
+ from sqlalchemy.orm import DeclarativeBase
9
+ from sqlalchemy.orm.attributes import InstrumentedAttribute
10
+
11
+ from ..exceptions import NoSearchableFieldsError
12
+
13
+ if TYPE_CHECKING:
14
+ from sqlalchemy.sql.elements import ColumnElement
15
+
16
+ SearchFieldType = InstrumentedAttribute[Any] | tuple[InstrumentedAttribute[Any], ...]
17
+
18
+
19
+ @dataclass
20
+ class SearchConfig:
21
+ """Advanced search configuration.
22
+
23
+ Attributes:
24
+ query: The search string
25
+ fields: Fields to search (columns or tuples for relationships)
26
+ case_sensitive: Case-sensitive search (default: False)
27
+ match_mode: "any" (OR) or "all" (AND) to combine fields
28
+ """
29
+
30
+ query: str
31
+ fields: Sequence[SearchFieldType] | None = None
32
+ case_sensitive: bool = False
33
+ match_mode: Literal["any", "all"] = "any"
34
+
35
+
36
+ def get_searchable_fields(
37
+ model: type[DeclarativeBase],
38
+ *,
39
+ include_relationships: bool = True,
40
+ max_depth: int = 1,
41
+ ) -> list[SearchFieldType]:
42
+ """Auto-detect String fields on a model and its relationships.
43
+
44
+ Args:
45
+ model: SQLAlchemy model class
46
+ include_relationships: Include fields from many-to-one/one-to-one relationships
47
+ max_depth: Max depth for relationship traversal (default: 1)
48
+
49
+ Returns:
50
+ List of columns and tuples (relationship, column)
51
+ """
52
+ fields: list[SearchFieldType] = []
53
+ mapper = model.__mapper__
54
+
55
+ # Direct String columns
56
+ for col in mapper.columns:
57
+ if isinstance(col.type, String):
58
+ fields.append(getattr(model, col.key))
59
+
60
+ # Relationships (one-to-one, many-to-one only)
61
+ if include_relationships and max_depth > 0:
62
+ for rel_name, rel_prop in mapper.relationships.items():
63
+ if rel_prop.uselist: # Skip collections (one-to-many, many-to-many)
64
+ continue
65
+
66
+ rel_attr = getattr(model, rel_name)
67
+ related_model = rel_prop.mapper.class_
68
+
69
+ for col in related_model.__mapper__.columns:
70
+ if isinstance(col.type, String):
71
+ fields.append((rel_attr, getattr(related_model, col.key)))
72
+
73
+ return fields
74
+
75
+
76
+ def build_search_filters(
77
+ model: type[DeclarativeBase],
78
+ search: str | SearchConfig,
79
+ search_fields: Sequence[SearchFieldType] | None = None,
80
+ default_fields: Sequence[SearchFieldType] | None = None,
81
+ ) -> tuple[list["ColumnElement[bool]"], list[InstrumentedAttribute[Any]]]:
82
+ """Build SQLAlchemy filter conditions for search.
83
+
84
+ Args:
85
+ model: SQLAlchemy model class
86
+ search: Search string or SearchConfig
87
+ search_fields: Fields specified per-call (takes priority)
88
+ default_fields: Default fields (from ClassVar)
89
+
90
+ Returns:
91
+ Tuple of (filter_conditions, joins_needed)
92
+ """
93
+ # Normalize input
94
+ if isinstance(search, str):
95
+ config = SearchConfig(query=search, fields=search_fields)
96
+ else:
97
+ config = search
98
+ if search_fields is not None:
99
+ config = SearchConfig(
100
+ query=config.query,
101
+ fields=search_fields,
102
+ case_sensitive=config.case_sensitive,
103
+ match_mode=config.match_mode,
104
+ )
105
+
106
+ if not config.query or not config.query.strip():
107
+ return [], []
108
+
109
+ # Determine which fields to search
110
+ fields = config.fields or default_fields or get_searchable_fields(model)
111
+
112
+ if not fields:
113
+ raise NoSearchableFieldsError(model)
114
+
115
+ query = config.query.strip()
116
+ filters: list[ColumnElement[bool]] = []
117
+ joins: list[InstrumentedAttribute[Any]] = []
118
+ added_joins: set[str] = set()
119
+
120
+ for field in fields:
121
+ if isinstance(field, tuple):
122
+ # Relationship: (User.role, Role.name) or deeper
123
+ for rel in field[:-1]:
124
+ rel_key = str(rel)
125
+ if rel_key not in added_joins:
126
+ joins.append(rel)
127
+ added_joins.add(rel_key)
128
+ column = field[-1]
129
+ else:
130
+ column = field
131
+
132
+ # Build the filter
133
+ if config.case_sensitive:
134
+ filters.append(column.like(f"%{query}%"))
135
+ else:
136
+ filters.append(column.ilike(f"%{query}%"))
137
+
138
+ if not filters:
139
+ return [], []
140
+
141
+ # Combine based on match_mode
142
+ if config.match_mode == "any":
143
+ return [or_(*filters)], joins
144
+ else:
145
+ return filters, joins
@@ -2,6 +2,7 @@ from .exceptions import (
2
2
  ApiException,
3
3
  ConflictError,
4
4
  ForbiddenError,
5
+ NoSearchableFieldsError,
5
6
  NotFoundError,
6
7
  UnauthorizedError,
7
8
  generate_error_responses,
@@ -14,6 +15,7 @@ __all__ = [
14
15
  "ApiException",
15
16
  "ConflictError",
16
17
  "ForbiddenError",
18
+ "NoSearchableFieldsError",
17
19
  "NotFoundError",
18
20
  "UnauthorizedError",
19
21
  ]
@@ -119,6 +119,25 @@ class RoleNotFoundError(NotFoundError):
119
119
  )
120
120
 
121
121
 
122
+ class NoSearchableFieldsError(ApiException):
123
+ """Raised when search is requested but no searchable fields are available."""
124
+
125
+ api_error = ApiError(
126
+ code=400,
127
+ msg="No Searchable Fields",
128
+ desc="No searchable fields configured for this resource.",
129
+ err_code="SEARCH-400",
130
+ )
131
+
132
+ def __init__(self, model: type) -> None:
133
+ self.model = model
134
+ detail = (
135
+ f"No searchable fields found for model '{model.__name__}'. "
136
+ "Provide 'search_fields' parameter or set 'searchable_fields' on the CRUD class."
137
+ )
138
+ super().__init__(detail)
139
+
140
+
122
141
  def generate_error_responses(
123
142
  *errors: type[ApiException],
124
143
  ) -> dict[int | str, dict[str, Any]]:
@@ -1,11 +1,6 @@
1
- from .fixtures import (
2
- Context,
3
- FixtureRegistry,
4
- LoadStrategy,
5
- load_fixtures,
6
- load_fixtures_by_context,
7
- )
8
- from .utils import get_obj_by_attr
1
+ from .enum import LoadStrategy
2
+ from .registry import Context, FixtureRegistry
3
+ from .utils import get_obj_by_attr, load_fixtures, load_fixtures_by_context
9
4
 
10
5
  __all__ = [
11
6
  "Context",
@@ -16,12 +11,3 @@ __all__ = [
16
11
  "load_fixtures_by_context",
17
12
  "register_fixtures",
18
13
  ]
19
-
20
-
21
- # We lazy-load register_fixtures to avoid needing pytest when using fixtures CLI
22
- def __getattr__(name: str):
23
- if name == "register_fixtures":
24
- from .pytest_plugin import register_fixtures
25
-
26
- return register_fixtures
27
- raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
@@ -0,0 +1,30 @@
1
+ from enum import Enum
2
+
3
+
4
+ class LoadStrategy(str, Enum):
5
+ """Strategy for loading fixtures into the database."""
6
+
7
+ INSERT = "insert"
8
+ """Insert new records. Fails if record already exists."""
9
+
10
+ MERGE = "merge"
11
+ """Insert or update based on primary key (SQLAlchemy merge)."""
12
+
13
+ SKIP_EXISTING = "skip_existing"
14
+ """Insert only if record doesn't exist (based on primary key)."""
15
+
16
+
17
+ class Context(str, Enum):
18
+ """Predefined fixture contexts."""
19
+
20
+ BASE = "base"
21
+ """Base fixtures loaded in all environments."""
22
+
23
+ PRODUCTION = "production"
24
+ """Production-only fixtures."""
25
+
26
+ DEVELOPMENT = "development"
27
+ """Development fixtures."""
28
+
29
+ TESTING = "testing"
30
+ """Test fixtures."""
@@ -3,46 +3,15 @@
3
3
  import logging
4
4
  from collections.abc import Callable, Sequence
5
5
  from dataclasses import dataclass, field
6
- from enum import Enum
7
6
  from typing import Any, cast
8
7
 
9
- from sqlalchemy.ext.asyncio import AsyncSession
10
8
  from sqlalchemy.orm import DeclarativeBase
11
9
 
12
- from ..db import get_transaction
10
+ from .enum import Context
13
11
 
14
12
  logger = logging.getLogger(__name__)
15
13
 
16
14
 
17
- class LoadStrategy(str, Enum):
18
- """Strategy for loading fixtures into the database."""
19
-
20
- INSERT = "insert"
21
- """Insert new records. Fails if record already exists."""
22
-
23
- MERGE = "merge"
24
- """Insert or update based on primary key (SQLAlchemy merge)."""
25
-
26
- SKIP_EXISTING = "skip_existing"
27
- """Insert only if record doesn't exist (based on primary key)."""
28
-
29
-
30
- class Context(str, Enum):
31
- """Predefined fixture contexts."""
32
-
33
- BASE = "base"
34
- """Base fixtures loaded in all environments."""
35
-
36
- PRODUCTION = "production"
37
- """Production-only fixtures."""
38
-
39
- DEVELOPMENT = "development"
40
- """Development fixtures."""
41
-
42
- TESTING = "testing"
43
- """Test fixtures."""
44
-
45
-
46
15
  @dataclass
47
16
  class Fixture:
48
17
  """A fixture definition with metadata."""
@@ -204,118 +173,3 @@ class FixtureRegistry:
204
173
  all_deps.update(deps)
205
174
 
206
175
  return self.resolve_dependencies(*all_deps)
207
-
208
-
209
- async def load_fixtures(
210
- session: AsyncSession,
211
- registry: FixtureRegistry,
212
- *names: str,
213
- strategy: LoadStrategy = LoadStrategy.MERGE,
214
- ) -> dict[str, list[DeclarativeBase]]:
215
- """Load specific fixtures by name with dependencies.
216
-
217
- Args:
218
- session: Database session
219
- registry: Fixture registry
220
- *names: Fixture names to load (dependencies auto-resolved)
221
- strategy: How to handle existing records
222
-
223
- Returns:
224
- Dict mapping fixture names to loaded instances
225
-
226
- Example:
227
- # Loads 'roles' first (dependency), then 'users'
228
- result = await load_fixtures(session, fixtures, "users")
229
- print(result["users"]) # [User(...), ...]
230
- """
231
- ordered = registry.resolve_dependencies(*names)
232
- return await _load_ordered(session, registry, ordered, strategy)
233
-
234
-
235
- async def load_fixtures_by_context(
236
- session: AsyncSession,
237
- registry: FixtureRegistry,
238
- *contexts: str | Context,
239
- strategy: LoadStrategy = LoadStrategy.MERGE,
240
- ) -> dict[str, list[DeclarativeBase]]:
241
- """Load all fixtures for specific contexts.
242
-
243
- Args:
244
- session: Database session
245
- registry: Fixture registry
246
- *contexts: Contexts to load (e.g., Context.BASE, Context.TESTING)
247
- strategy: How to handle existing records
248
-
249
- Returns:
250
- Dict mapping fixture names to loaded instances
251
-
252
- Example:
253
- # Load base + testing fixtures
254
- await load_fixtures_by_context(
255
- session, fixtures,
256
- Context.BASE, Context.TESTING
257
- )
258
- """
259
- ordered = registry.resolve_context_dependencies(*contexts)
260
- return await _load_ordered(session, registry, ordered, strategy)
261
-
262
-
263
- async def _load_ordered(
264
- session: AsyncSession,
265
- registry: FixtureRegistry,
266
- ordered_names: list[str],
267
- strategy: LoadStrategy,
268
- ) -> dict[str, list[DeclarativeBase]]:
269
- """Load fixtures in order."""
270
- results: dict[str, list[DeclarativeBase]] = {}
271
-
272
- for name in ordered_names:
273
- fixture = registry.get(name)
274
- instances = list(fixture.func())
275
-
276
- if not instances:
277
- results[name] = []
278
- continue
279
-
280
- model_name = type(instances[0]).__name__
281
- loaded: list[DeclarativeBase] = []
282
-
283
- async with get_transaction(session):
284
- for instance in instances:
285
- if strategy == LoadStrategy.INSERT:
286
- session.add(instance)
287
- loaded.append(instance)
288
-
289
- elif strategy == LoadStrategy.MERGE:
290
- merged = await session.merge(instance)
291
- loaded.append(merged)
292
-
293
- elif strategy == LoadStrategy.SKIP_EXISTING:
294
- pk = _get_primary_key(instance)
295
- if pk is not None:
296
- existing = await session.get(type(instance), pk)
297
- if existing is None:
298
- session.add(instance)
299
- loaded.append(instance)
300
- else:
301
- session.add(instance)
302
- loaded.append(instance)
303
-
304
- results[name] = loaded
305
- logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
306
-
307
- return results
308
-
309
-
310
- def _get_primary_key(instance: DeclarativeBase) -> Any | None:
311
- """Get the primary key value of a model instance."""
312
- mapper = instance.__class__.__mapper__
313
- pk_cols = mapper.primary_key
314
-
315
- if len(pk_cols) == 1:
316
- return getattr(instance, pk_cols[0].name, None)
317
-
318
- pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
319
- if all(v is not None for v in pk_values):
320
- return pk_values
321
- return None
@@ -1,8 +1,16 @@
1
+ import logging
1
2
  from collections.abc import Callable, Sequence
2
3
  from typing import Any, TypeVar
3
4
 
5
+ from sqlalchemy.ext.asyncio import AsyncSession
4
6
  from sqlalchemy.orm import DeclarativeBase
5
7
 
8
+ from ..db import get_transaction
9
+ from .enum import LoadStrategy
10
+ from .registry import Context, FixtureRegistry
11
+
12
+ logger = logging.getLogger(__name__)
13
+
6
14
  T = TypeVar("T", bound=DeclarativeBase)
7
15
 
8
16
 
@@ -24,3 +32,118 @@ def get_obj_by_attr(
24
32
  StopIteration: If no matching object is found.
25
33
  """
26
34
  return next(obj for obj in fixtures() if getattr(obj, attr_name) == value)
35
+
36
+
37
+ async def load_fixtures(
38
+ session: AsyncSession,
39
+ registry: FixtureRegistry,
40
+ *names: str,
41
+ strategy: LoadStrategy = LoadStrategy.MERGE,
42
+ ) -> dict[str, list[DeclarativeBase]]:
43
+ """Load specific fixtures by name with dependencies.
44
+
45
+ Args:
46
+ session: Database session
47
+ registry: Fixture registry
48
+ *names: Fixture names to load (dependencies auto-resolved)
49
+ strategy: How to handle existing records
50
+
51
+ Returns:
52
+ Dict mapping fixture names to loaded instances
53
+
54
+ Example:
55
+ # Loads 'roles' first (dependency), then 'users'
56
+ result = await load_fixtures(session, fixtures, "users")
57
+ print(result["users"]) # [User(...), ...]
58
+ """
59
+ ordered = registry.resolve_dependencies(*names)
60
+ return await _load_ordered(session, registry, ordered, strategy)
61
+
62
+
63
+ async def load_fixtures_by_context(
64
+ session: AsyncSession,
65
+ registry: FixtureRegistry,
66
+ *contexts: str | Context,
67
+ strategy: LoadStrategy = LoadStrategy.MERGE,
68
+ ) -> dict[str, list[DeclarativeBase]]:
69
+ """Load all fixtures for specific contexts.
70
+
71
+ Args:
72
+ session: Database session
73
+ registry: Fixture registry
74
+ *contexts: Contexts to load (e.g., Context.BASE, Context.TESTING)
75
+ strategy: How to handle existing records
76
+
77
+ Returns:
78
+ Dict mapping fixture names to loaded instances
79
+
80
+ Example:
81
+ # Load base + testing fixtures
82
+ await load_fixtures_by_context(
83
+ session, fixtures,
84
+ Context.BASE, Context.TESTING
85
+ )
86
+ """
87
+ ordered = registry.resolve_context_dependencies(*contexts)
88
+ return await _load_ordered(session, registry, ordered, strategy)
89
+
90
+
91
+ async def _load_ordered(
92
+ session: AsyncSession,
93
+ registry: FixtureRegistry,
94
+ ordered_names: list[str],
95
+ strategy: LoadStrategy,
96
+ ) -> dict[str, list[DeclarativeBase]]:
97
+ """Load fixtures in order."""
98
+ results: dict[str, list[DeclarativeBase]] = {}
99
+
100
+ for name in ordered_names:
101
+ fixture = registry.get(name)
102
+ instances = list(fixture.func())
103
+
104
+ if not instances:
105
+ results[name] = []
106
+ continue
107
+
108
+ model_name = type(instances[0]).__name__
109
+ loaded: list[DeclarativeBase] = []
110
+
111
+ async with get_transaction(session):
112
+ for instance in instances:
113
+ if strategy == LoadStrategy.INSERT:
114
+ session.add(instance)
115
+ loaded.append(instance)
116
+
117
+ elif strategy == LoadStrategy.MERGE:
118
+ merged = await session.merge(instance)
119
+ loaded.append(merged)
120
+
121
+ elif strategy == LoadStrategy.SKIP_EXISTING:
122
+ pk = _get_primary_key(instance)
123
+ if pk is not None:
124
+ existing = await session.get(type(instance), pk)
125
+ if existing is None:
126
+ session.add(instance)
127
+ loaded.append(instance)
128
+ else:
129
+ session.add(instance)
130
+ loaded.append(instance)
131
+
132
+ results[name] = loaded
133
+ logger.info(f"Loaded fixture '{name}': {len(loaded)} {model_name}(s)")
134
+
135
+ return results
136
+
137
+
138
+ def _get_primary_key(instance: DeclarativeBase) -> Any | None:
139
+ """Get the primary key value of a model instance."""
140
+ mapper = instance.__class__.__mapper__
141
+ pk_cols = mapper.primary_key
142
+
143
+ if len(pk_cols) == 1:
144
+ return getattr(instance, pk_cols[0].name, None)
145
+
146
+ pk_values = tuple(getattr(instance, col.name, None) for col in pk_cols)
147
+ if all(v is not None for v in pk_values):
148
+ return pk_values
149
+ return None
@@ -0,0 +1,8 @@
1
+ from .plugin import register_fixtures
2
+ from .utils import create_async_client, create_db_session
3
+
4
+ __all__ = [
5
+ "create_async_client",
6
+ "create_db_session",
7
+ "register_fixtures",
8
+ ]
@@ -59,7 +59,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
59
59
  from sqlalchemy.orm import DeclarativeBase
60
60
 
61
61
  from ..db import get_transaction
62
- from .fixtures import FixtureRegistry, LoadStrategy
62
+ from ..fixtures import FixtureRegistry, LoadStrategy
63
63
 
64
64
 
65
65
  def register_fixtures(
@@ -0,0 +1,110 @@
1
+ """Pytest helper utilities for FastAPI testing."""
2
+
3
+ from collections.abc import AsyncGenerator
4
+ from contextlib import asynccontextmanager
5
+ from typing import Any
6
+
7
+ from httpx import ASGITransport, AsyncClient
8
+ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
9
+ from sqlalchemy.orm import DeclarativeBase
10
+
11
+ from ..db import create_db_context
12
+
13
+
14
+ @asynccontextmanager
15
+ async def create_async_client(
16
+ app: Any,
17
+ base_url: str = "http://test",
18
+ ) -> AsyncGenerator[AsyncClient, None]:
19
+ """Create an async httpx client for testing FastAPI applications.
20
+
21
+ Args:
22
+ app: FastAPI application instance.
23
+ base_url: Base URL for requests. Defaults to "http://test".
24
+
25
+ Yields:
26
+ An AsyncClient configured for the app.
27
+
28
+ Example:
29
+ ```python
30
+ from fastapi import FastAPI
31
+ from fastapi_toolsets.pytest import create_async_client
32
+
33
+ app = FastAPI()
34
+
35
+ @pytest.fixture
36
+ async def client():
37
+ async with create_async_client(app) as c:
38
+ yield c
39
+
40
+ async def test_endpoint(client: AsyncClient):
41
+ response = await client.get("/health")
42
+ assert response.status_code == 200
43
+ ```
44
+ """
45
+ transport = ASGITransport(app=app)
46
+ async with AsyncClient(transport=transport, base_url=base_url) as client:
47
+ yield client
48
+
49
+
50
+ @asynccontextmanager
51
+ async def create_db_session(
52
+ database_url: str,
53
+ base: type[DeclarativeBase],
54
+ *,
55
+ echo: bool = False,
56
+ expire_on_commit: bool = False,
57
+ drop_tables: bool = True,
58
+ ) -> AsyncGenerator[AsyncSession, None]:
59
+ """Create a database session for testing.
60
+
61
+ Creates tables before yielding the session and optionally drops them after.
62
+ Each call creates a fresh engine and session for test isolation.
63
+
64
+ Args:
65
+ database_url: Database connection URL (e.g., "postgresql+asyncpg://...").
66
+ base: SQLAlchemy DeclarativeBase class containing model metadata.
67
+ echo: Enable SQLAlchemy query logging. Defaults to False.
68
+ expire_on_commit: Expire objects after commit. Defaults to False.
69
+ drop_tables: Drop tables after test. Defaults to True.
70
+
71
+ Yields:
72
+ An AsyncSession ready for database operations.
73
+
74
+ Example:
75
+ ```python
76
+ from fastapi_toolsets.pytest import create_db_session
77
+ from app.models import Base
78
+
79
+ DATABASE_URL = "postgresql+asyncpg://user:pass@localhost/test_db"
80
+
81
+ @pytest.fixture
82
+ async def db_session():
83
+ async with create_db_session(DATABASE_URL, Base) as session:
84
+ yield session
85
+
86
+ async def test_create_user(db_session: AsyncSession):
87
+ user = User(name="test")
88
+ db_session.add(user)
89
+ await db_session.commit()
90
+ ```
91
+ """
92
+ engine = create_async_engine(database_url, echo=echo)
93
+
94
+ try:
95
+ # Create tables
96
+ async with engine.begin() as conn:
97
+ await conn.run_sync(base.metadata.create_all)
98
+
99
+ # Create session using existing db context utility
100
+ session_maker = async_sessionmaker(engine, expire_on_commit=expire_on_commit)
101
+ get_session = create_db_context(session_maker)
102
+
103
+ async with get_session() as session:
104
+ yield session
105
+
106
+ if drop_tables:
107
+ async with engine.begin() as conn:
108
+ await conn.run_sync(base.metadata.drop_all)
109
+ finally:
110
+ await engine.dispose()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: fastapi-toolsets
3
- Version: 0.2.0
3
+ Version: 0.4.0
4
4
  Summary: Reusable tools for FastAPI: async CRUD, fixtures, CLI, and standardized responses for SQLAlchemy + PostgreSQL
5
5
  Keywords: fastapi,sqlalchemy,postgresql
6
6
  Author: d3vyce
@@ -20,6 +20,7 @@ Classifier: Programming Language :: Python :: 3 :: Only
20
20
  Classifier: Programming Language :: Python :: 3.11
21
21
  Classifier: Programming Language :: Python :: 3.12
22
22
  Classifier: Programming Language :: Python :: 3.13
23
+ Classifier: Programming Language :: Python :: 3.14
23
24
  Classifier: Topic :: Software Development :: Libraries :: Python Modules
24
25
  Classifier: Topic :: Software Development :: Libraries
25
26
  Classifier: Topic :: Software Development
@@ -0,0 +1,26 @@
1
+ fastapi_toolsets/__init__.py,sha256=gzOvHPBZ5Kdja8XWMQM-zTxHKPJ7M7ptud7tUsd9bws,820
2
+ fastapi_toolsets/cli/__init__.py,sha256=QAcrenphE7D5toid_Kn777Cy1icSOWiEBjSE_oCuU4o,111
3
+ fastapi_toolsets/cli/app.py,sha256=G3y3PNN3biHs8GtiiGsrX2X8nXNjkrsUhHLXQo2-MXc,2588
4
+ fastapi_toolsets/cli/commands/__init__.py,sha256=BogehmsY6olwLdfIBzviuppXP1LLl9znnxtmji3eLwI,29
5
+ fastapi_toolsets/cli/commands/fixtures.py,sha256=qiC2dcrJ_Rb1PRzx6EycTFQQXwa_GQUVoptpx4chf9k,6679
6
+ fastapi_toolsets/crud/__init__.py,sha256=LRR57jsFna2IqSj6aTomMzFcaQbx7WVo5lnfhcTSU1g,369
7
+ fastapi_toolsets/crud/factory.py,sha256=c5aH2c38Qa44pQ3XKoJTcRox5x4rD4tuEJQ5DKFjSnc,13014
8
+ fastapi_toolsets/crud/search.py,sha256=a46SGMg484uv2DruNfEB-Rq6lJRi_K-Nr8tFTWQ2ouY,4530
9
+ fastapi_toolsets/db.py,sha256=YUj5CrxCnREg7AqpJLNrLR2RDIOCS7stQCNOSS3cRho,5619
10
+ fastapi_toolsets/exceptions/__init__.py,sha256=wlV4pVXuGdOtUvlThRJXmEc8g8Nmwt8MOMG_A-3j9zw,451
11
+ fastapi_toolsets/exceptions/exceptions.py,sha256=hu8_lvE9KmnYID9YgJqlzZMkCD0kASPrGAmN1hUe2bY,5086
12
+ fastapi_toolsets/exceptions/handler.py,sha256=IXfKiIr_LPo-11PRpOIrNRAXBkeQ5TdLcu3Gy-r6ChU,5916
13
+ fastapi_toolsets/fixtures/__init__.py,sha256=i5N6dt4LLVxhC0fBNhDTokuqUf43oXJChBkVMQ94hLA,328
14
+ fastapi_toolsets/fixtures/enum.py,sha256=02T4CrkH3-A3mPxpHaLzBQD4yzqExwjycaDBJr1ameA,715
15
+ fastapi_toolsets/fixtures/registry.py,sha256=lfoLdC6aZeJQR7_l0g_P5Y6DChbs8_zAgLQsR0Plmfg,5276
16
+ fastapi_toolsets/fixtures/utils.py,sha256=DlsGBVl0zyxtKX5E3CEcV7rhzBYng6KD5hELD7bK0wo,4711
17
+ fastapi_toolsets/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
+ fastapi_toolsets/pytest/__init__.py,sha256=0GnnFxWNfpaShBBYsfRGIgSygwC1eo8TydmXCA3tmoM,188
19
+ fastapi_toolsets/pytest/plugin.py,sha256=lbEiumS2zi7jARY6eYBUPAlfKCplbLFrZXcmp0-RkcA,6892
20
+ fastapi_toolsets/pytest/utils.py,sha256=VqkxtbpEU8w7-0xfcZG0m8Tpn3LtdnvAJMyqWS7WtIw,3447
21
+ fastapi_toolsets/schemas.py,sha256=LBzrq4s5VWYeQqlUfOEvWDtpFdO8scgY0LRypk9KUAE,2639
22
+ fastapi_toolsets-0.4.0.dist-info/licenses/LICENSE,sha256=V2jCjI-VPB-veGY2Ktb0sU4vT_TldRciZ9lCE98bMoE,1063
23
+ fastapi_toolsets-0.4.0.dist-info/WHEEL,sha256=e_m4S054HL0hyR3CpOk-b7Q7fDX6BuFkgL5OjAExXas,80
24
+ fastapi_toolsets-0.4.0.dist-info/entry_points.txt,sha256=pNU38Nn_DXBgYd-nLZCizMvrrdaPhHmkRwouDoBqvzw,63
25
+ fastapi_toolsets-0.4.0.dist-info/METADATA,sha256=cRTaYwgiAEVdU_DD4lG7idF6agtcxgSHpFsFIwp-7HY,4221
26
+ fastapi_toolsets-0.4.0.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: uv 0.9.26
2
+ Generator: uv 0.9.27
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
@@ -1,21 +0,0 @@
1
- fastapi_toolsets/__init__.py,sha256=UxUjB9wMH-YmCzQimjgZeWuhEgtu3BKzimlEn-Atg0Q,820
2
- fastapi_toolsets/cli/__init__.py,sha256=QAcrenphE7D5toid_Kn777Cy1icSOWiEBjSE_oCuU4o,111
3
- fastapi_toolsets/cli/app.py,sha256=G3y3PNN3biHs8GtiiGsrX2X8nXNjkrsUhHLXQo2-MXc,2588
4
- fastapi_toolsets/cli/commands/__init__.py,sha256=BogehmsY6olwLdfIBzviuppXP1LLl9znnxtmji3eLwI,29
5
- fastapi_toolsets/cli/commands/fixtures.py,sha256=qiC2dcrJ_Rb1PRzx6EycTFQQXwa_GQUVoptpx4chf9k,6679
6
- fastapi_toolsets/crud.py,sha256=pD26V91y0Z5f9ft0I6ggY1EJ-Oh-QATgANgzEyj4EHU,11370
7
- fastapi_toolsets/db.py,sha256=YUj5CrxCnREg7AqpJLNrLR2RDIOCS7stQCNOSS3cRho,5619
8
- fastapi_toolsets/exceptions/__init__.py,sha256=PDiTg4NpEUhGi5p9Sfvn1FtxdlD9-Y4OPhWCQz6fW9g,391
9
- fastapi_toolsets/exceptions/exceptions.py,sha256=O6fqDbPfEFSj1oi_vo3-gnES7RPpSf-l_yXh4fTCBjg,4470
10
- fastapi_toolsets/exceptions/handler.py,sha256=IXfKiIr_LPo-11PRpOIrNRAXBkeQ5TdLcu3Gy-r6ChU,5916
11
- fastapi_toolsets/fixtures/__init__.py,sha256=1oDN4Of3utEyyrv4iFee4TWf1symsFsGUsmOP2RqLgc,645
12
- fastapi_toolsets/fixtures/fixtures.py,sha256=RB073yojMQoRCxP_-pvZWSW_viPQsD3RRJ5NhEZ2wZQ,9720
13
- fastapi_toolsets/fixtures/pytest_plugin.py,sha256=J2PjasekQUxzurk7_MVncQRcIZ1AIawsO6OlqKqKYT8,6891
14
- fastapi_toolsets/fixtures/utils.py,sha256=RQnm5eFPNFQxms_QnBDraglsXj67H6VwCnn2OPnyEUQ,824
15
- fastapi_toolsets/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
- fastapi_toolsets/schemas.py,sha256=LBzrq4s5VWYeQqlUfOEvWDtpFdO8scgY0LRypk9KUAE,2639
17
- fastapi_toolsets-0.2.0.dist-info/licenses/LICENSE,sha256=V2jCjI-VPB-veGY2Ktb0sU4vT_TldRciZ9lCE98bMoE,1063
18
- fastapi_toolsets-0.2.0.dist-info/WHEEL,sha256=XV0cjMrO7zXhVAIyyc8aFf1VjZ33Fen4IiJk5zFlC3g,80
19
- fastapi_toolsets-0.2.0.dist-info/entry_points.txt,sha256=pNU38Nn_DXBgYd-nLZCizMvrrdaPhHmkRwouDoBqvzw,63
20
- fastapi_toolsets-0.2.0.dist-info/METADATA,sha256=ZgmtHOB0t0BqNU4CeFjPYVlwfwPGJ4zUQDm9k1Q0Ipc,4170
21
- fastapi_toolsets-0.2.0.dist-info/RECORD,,