fastapi-toolsets 2.0.0__tar.gz → 2.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36) hide show
  1. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/PKG-INFO +1 -1
  2. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/pyproject.toml +1 -1
  3. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/__init__.py +1 -1
  4. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/crud/factory.py +131 -12
  5. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/db.py +69 -2
  6. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/metrics/handler.py +1 -1
  7. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/metrics/registry.py +21 -45
  8. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/pytest/utils.py +129 -159
  9. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/LICENSE +0 -0
  10. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/README.md +0 -0
  11. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/_imports.py +0 -0
  12. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/cli/__init__.py +0 -0
  13. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/cli/app.py +0 -0
  14. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/cli/commands/__init__.py +0 -0
  15. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/cli/commands/fixtures.py +0 -0
  16. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/cli/config.py +0 -0
  17. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/cli/pyproject.py +0 -0
  18. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/cli/utils.py +0 -0
  19. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/crud/__init__.py +0 -0
  20. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/crud/search.py +0 -0
  21. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/dependencies.py +0 -0
  22. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/exceptions/__init__.py +0 -0
  23. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/exceptions/exceptions.py +0 -0
  24. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/exceptions/handler.py +0 -0
  25. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/fixtures/__init__.py +0 -0
  26. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/fixtures/enum.py +0 -0
  27. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/fixtures/registry.py +0 -0
  28. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/fixtures/utils.py +0 -0
  29. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/logger.py +0 -0
  30. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/metrics/__init__.py +0 -0
  31. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/models.py +0 -0
  32. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/py.typed +0 -0
  33. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/pytest/__init__.py +0 -0
  34. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/pytest/plugin.py +0 -0
  35. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/schemas.py +0 -0
  36. {fastapi_toolsets-2.0.0 → fastapi_toolsets-2.2.0}/src/fastapi_toolsets/types.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: fastapi-toolsets
3
- Version: 2.0.0
3
+ Version: 2.2.0
4
4
  Summary: Production-ready utilities for FastAPI applications
5
5
  Keywords: fastapi,sqlalchemy,postgresql
6
6
  Author: d3vyce
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "fastapi-toolsets"
3
- version = "2.0.0"
3
+ version = "2.2.0"
4
4
  description = "Production-ready utilities for FastAPI applications"
5
5
  readme = "README.md"
6
6
  license = "MIT"
@@ -21,4 +21,4 @@ Example usage:
21
21
  return Response(data={"user": user.username}, message="Success")
22
22
  """
23
23
 
24
- __version__ = "2.0.0"
24
+ __version__ = "2.2.0"
@@ -14,7 +14,6 @@ from typing import Any, ClassVar, Generic, Literal, Self, cast, overload
14
14
  from fastapi import Query
15
15
  from pydantic import BaseModel
16
16
  from sqlalchemy import Date, DateTime, Float, Integer, Numeric, Uuid, and_, func, select
17
- from sqlalchemy import delete as sql_delete
18
17
  from sqlalchemy.dialects.postgresql import insert
19
18
  from sqlalchemy.exc import NoResultFound
20
19
  from sqlalchemy.ext.asyncio import AsyncSession
@@ -410,6 +409,82 @@ class AsyncCrud(Generic[ModelType]):
410
409
  NotFoundError: If no record found
411
410
  MultipleResultsFound: If more than one record found
412
411
  """
412
+ result = await cls.get_or_none(
413
+ session,
414
+ filters,
415
+ joins=joins,
416
+ outer_join=outer_join,
417
+ with_for_update=with_for_update,
418
+ load_options=load_options,
419
+ schema=schema,
420
+ )
421
+ if result is None:
422
+ raise NotFoundError()
423
+ return result
424
+
425
+ @overload
426
+ @classmethod
427
+ async def get_or_none( # pragma: no cover
428
+ cls: type[Self],
429
+ session: AsyncSession,
430
+ filters: list[Any],
431
+ *,
432
+ joins: JoinType | None = None,
433
+ outer_join: bool = False,
434
+ with_for_update: bool = False,
435
+ load_options: list[ExecutableOption] | None = None,
436
+ schema: type[SchemaType],
437
+ ) -> Response[SchemaType] | None: ...
438
+
439
+ @overload
440
+ @classmethod
441
+ async def get_or_none( # pragma: no cover
442
+ cls: type[Self],
443
+ session: AsyncSession,
444
+ filters: list[Any],
445
+ *,
446
+ joins: JoinType | None = None,
447
+ outer_join: bool = False,
448
+ with_for_update: bool = False,
449
+ load_options: list[ExecutableOption] | None = None,
450
+ schema: None = ...,
451
+ ) -> ModelType | None: ...
452
+
453
+ @classmethod
454
+ async def get_or_none(
455
+ cls: type[Self],
456
+ session: AsyncSession,
457
+ filters: list[Any],
458
+ *,
459
+ joins: JoinType | None = None,
460
+ outer_join: bool = False,
461
+ with_for_update: bool = False,
462
+ load_options: list[ExecutableOption] | None = None,
463
+ schema: type[BaseModel] | None = None,
464
+ ) -> ModelType | Response[Any] | None:
465
+ """Get exactly one record, or ``None`` if not found.
466
+
467
+ Like :meth:`get` but returns ``None`` instead of raising
468
+ :class:`~fastapi_toolsets.exceptions.NotFoundError` when no record
469
+ matches the filters.
470
+
471
+ Args:
472
+ session: DB async session
473
+ filters: List of SQLAlchemy filter conditions
474
+ joins: List of (model, condition) tuples for joining related tables
475
+ outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
476
+ with_for_update: Lock the row for update
477
+ load_options: SQLAlchemy loader options (e.g., selectinload)
478
+ schema: Pydantic schema to serialize the result into. When provided,
479
+ the result is automatically wrapped in a ``Response[schema]``.
480
+
481
+ Returns:
482
+ Model instance, ``Response[schema]`` when ``schema`` is given,
483
+ or ``None`` when no record matches.
484
+
485
+ Raises:
486
+ MultipleResultsFound: If more than one record found
487
+ """
413
488
  q = select(cls.model)
414
489
  q = _apply_joins(q, joins, outer_join)
415
490
  q = q.where(and_(*filters))
@@ -419,12 +494,40 @@ class AsyncCrud(Generic[ModelType]):
419
494
  q = q.with_for_update()
420
495
  result = await session.execute(q)
421
496
  item = result.unique().scalar_one_or_none()
422
- if not item:
423
- raise NotFoundError()
424
- result = cast(ModelType, item)
497
+ if item is None:
498
+ return None
499
+ db_model = cast(ModelType, item)
425
500
  if schema:
426
- return Response(data=schema.model_validate(result))
427
- return result
501
+ return Response(data=schema.model_validate(db_model))
502
+ return db_model
503
+
504
+ @overload
505
+ @classmethod
506
+ async def first( # pragma: no cover
507
+ cls: type[Self],
508
+ session: AsyncSession,
509
+ filters: list[Any] | None = None,
510
+ *,
511
+ joins: JoinType | None = None,
512
+ outer_join: bool = False,
513
+ with_for_update: bool = False,
514
+ load_options: list[ExecutableOption] | None = None,
515
+ schema: type[SchemaType],
516
+ ) -> Response[SchemaType] | None: ...
517
+
518
+ @overload
519
+ @classmethod
520
+ async def first( # pragma: no cover
521
+ cls: type[Self],
522
+ session: AsyncSession,
523
+ filters: list[Any] | None = None,
524
+ *,
525
+ joins: JoinType | None = None,
526
+ outer_join: bool = False,
527
+ with_for_update: bool = False,
528
+ load_options: list[ExecutableOption] | None = None,
529
+ schema: None = ...,
530
+ ) -> ModelType | None: ...
428
531
 
429
532
  @classmethod
430
533
  async def first(
@@ -434,8 +537,10 @@ class AsyncCrud(Generic[ModelType]):
434
537
  *,
435
538
  joins: JoinType | None = None,
436
539
  outer_join: bool = False,
540
+ with_for_update: bool = False,
437
541
  load_options: list[ExecutableOption] | None = None,
438
- ) -> ModelType | None:
542
+ schema: type[BaseModel] | None = None,
543
+ ) -> ModelType | Response[Any] | None:
439
544
  """Get the first matching record, or None.
440
545
 
441
546
  Args:
@@ -443,10 +548,14 @@ class AsyncCrud(Generic[ModelType]):
443
548
  filters: List of SQLAlchemy filter conditions
444
549
  joins: List of (model, condition) tuples for joining related tables
445
550
  outer_join: Use LEFT OUTER JOIN instead of INNER JOIN
446
- load_options: SQLAlchemy loader options
551
+ with_for_update: Lock the row for update
552
+ load_options: SQLAlchemy loader options (e.g., selectinload)
553
+ schema: Pydantic schema to serialize the result into. When provided,
554
+ the result is automatically wrapped in a ``Response[schema]``.
447
555
 
448
556
  Returns:
449
- Model instance or None
557
+ Model instance, ``Response[schema]`` when ``schema`` is given,
558
+ or ``None`` when no record matches.
450
559
  """
451
560
  q = select(cls.model)
452
561
  q = _apply_joins(q, joins, outer_join)
@@ -454,8 +563,16 @@ class AsyncCrud(Generic[ModelType]):
454
563
  q = q.where(and_(*filters))
455
564
  if resolved := cls._resolve_load_options(load_options):
456
565
  q = q.options(*resolved)
566
+ if with_for_update:
567
+ q = q.with_for_update()
457
568
  result = await session.execute(q)
458
- return cast(ModelType | None, result.unique().scalars().first())
569
+ item = result.unique().scalars().first()
570
+ if item is None:
571
+ return None
572
+ db_model = cast(ModelType, item)
573
+ if schema:
574
+ return Response(data=schema.model_validate(db_model))
575
+ return db_model
459
576
 
460
577
  @classmethod
461
578
  async def get_multi(
@@ -674,8 +791,10 @@ class AsyncCrud(Generic[ModelType]):
674
791
  ``None``, or ``Response[None]`` when ``return_response=True``.
675
792
  """
676
793
  async with get_transaction(session):
677
- q = sql_delete(cls.model).where(and_(*filters))
678
- await session.execute(q)
794
+ result = await session.execute(select(cls.model).where(and_(*filters)))
795
+ objects = result.scalars().all()
796
+ for obj in objects:
797
+ await session.delete(obj)
679
798
  if return_response:
680
799
  return Response(data=None)
681
800
  return None
@@ -7,17 +7,19 @@ from enum import Enum
7
7
  from typing import Any, TypeVar
8
8
 
9
9
  from sqlalchemy import text
10
- from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
10
+ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
11
11
  from sqlalchemy.orm import DeclarativeBase
12
12
 
13
13
  from .exceptions import NotFoundError
14
14
 
15
15
  __all__ = [
16
16
  "LockMode",
17
+ "cleanup_tables",
18
+ "create_database",
17
19
  "create_db_context",
18
20
  "create_db_dependency",
19
- "lock_tables",
20
21
  "get_transaction",
22
+ "lock_tables",
21
23
  "wait_for_row_change",
22
24
  ]
23
25
 
@@ -188,6 +190,71 @@ async def lock_tables(
188
190
  yield session
189
191
 
190
192
 
193
+ async def create_database(
194
+ db_name: str,
195
+ *,
196
+ server_url: str,
197
+ ) -> None:
198
+ """Create a database.
199
+
200
+ Connects to *server_url* using ``AUTOCOMMIT`` isolation and issues a
201
+ ``CREATE DATABASE`` statement for *db_name*.
202
+
203
+ Args:
204
+ db_name: Name of the database to create.
205
+ server_url: URL used for server-level DDL (must point to an existing
206
+ database on the same server).
207
+
208
+ Example:
209
+ ```python
210
+ from fastapi_toolsets.db import create_database
211
+
212
+ SERVER_URL = "postgresql+asyncpg://postgres:postgres@localhost/postgres"
213
+ await create_database("myapp_test", server_url=SERVER_URL)
214
+ ```
215
+ """
216
+ engine = create_async_engine(server_url, isolation_level="AUTOCOMMIT")
217
+ try:
218
+ async with engine.connect() as conn:
219
+ await conn.execute(text(f"CREATE DATABASE {db_name}"))
220
+ finally:
221
+ await engine.dispose()
222
+
223
+
224
+ async def cleanup_tables(
225
+ session: AsyncSession,
226
+ base: type[DeclarativeBase],
227
+ ) -> None:
228
+ """Truncate all tables for fast between-test cleanup.
229
+
230
+ Executes a single ``TRUNCATE … RESTART IDENTITY CASCADE`` statement
231
+ across every table in *base*'s metadata, which is significantly faster
232
+ than dropping and re-creating tables between tests.
233
+
234
+ This is a no-op when the metadata contains no tables.
235
+
236
+ Args:
237
+ session: An active async database session.
238
+ base: SQLAlchemy DeclarativeBase class containing model metadata.
239
+
240
+ Example:
241
+ ```python
242
+ @pytest.fixture
243
+ async def db_session(worker_db_url):
244
+ async with create_db_session(worker_db_url, Base) as session:
245
+ yield session
246
+ await cleanup_tables(session, Base)
247
+ ```
248
+ """
249
+ tables = base.metadata.sorted_tables
250
+ if not tables:
251
+ return
252
+
253
+ table_names = ", ".join(f'"{t.name}"' for t in tables)
254
+ await session.execute(text(f"TRUNCATE {table_names} RESTART IDENTITY CASCADE"))
255
+ await session.commit()
256
+
257
+
191
258
  _M = TypeVar("_M", bound=DeclarativeBase)
192
259
 
193
260
 
@@ -51,7 +51,7 @@ def init_metrics(
51
51
  """
52
52
  for provider in registry.get_providers():
53
53
  logger.debug("Initialising metric provider '%s'", provider.name)
54
- provider.func()
54
+ registry._instances[provider.name] = provider.func()
55
55
 
56
56
  # Partition collectors and cache env check at startup — both are stable for the app lifetime.
57
57
  async_collectors = [
@@ -19,31 +19,11 @@ class Metric:
19
19
 
20
20
 
21
21
  class MetricsRegistry:
22
- """Registry for managing Prometheus metric providers and collectors.
23
-
24
- Example:
25
- ```python
26
- from prometheus_client import Counter, Gauge
27
- from fastapi_toolsets.metrics import MetricsRegistry
28
-
29
- metrics = MetricsRegistry()
30
-
31
- @metrics.register
32
- def http_requests():
33
- return Counter("http_requests_total", "Total HTTP requests", ["method", "status"])
34
-
35
- @metrics.register(name="db_pool")
36
- def database_pool_size():
37
- return Gauge("db_pool_size", "Database connection pool size")
38
-
39
- @metrics.register(collect=True)
40
- def collect_queue_depth(gauge=Gauge("queue_depth", "Current queue depth")):
41
- gauge.set(get_current_queue_depth())
42
- ```
43
- """
22
+ """Registry for managing Prometheus metric providers and collectors."""
44
23
 
45
24
  def __init__(self) -> None:
46
25
  self._metrics: dict[str, Metric] = {}
26
+ self._instances: dict[str, Any] = {}
47
27
 
48
28
  def register(
49
29
  self,
@@ -61,17 +41,6 @@ class MetricsRegistry:
61
41
  name: Metric name (defaults to function name).
62
42
  collect: If ``True``, the function is called on every scrape.
63
43
  If ``False`` (default), called once at init time.
64
-
65
- Example:
66
- ```python
67
- @metrics.register
68
- def my_counter():
69
- return Counter("my_counter", "A counter")
70
-
71
- @metrics.register(collect=True, name="queue")
72
- def collect_queue_depth():
73
- gauge.set(compute_depth())
74
- ```
75
44
  """
76
45
 
77
46
  def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
@@ -87,6 +56,25 @@ class MetricsRegistry:
87
56
  return decorator(func)
88
57
  return decorator
89
58
 
59
+ def get(self, name: str) -> Any:
60
+ """Return the metric instance created by a provider.
61
+
62
+ Args:
63
+ name: The metric name (defaults to the provider function name).
64
+
65
+ Raises:
66
+ KeyError: If the metric name is unknown or ``init_metrics`` has not
67
+ been called yet.
68
+ """
69
+ if name not in self._instances:
70
+ if name in self._metrics:
71
+ raise KeyError(
72
+ f"Metric '{name}' exists but has not been initialized yet. "
73
+ "Ensure init_metrics() has been called before accessing metric instances."
74
+ )
75
+ raise KeyError(f"Unknown metric '{name}'.")
76
+ return self._instances[name]
77
+
90
78
  def include_registry(self, registry: "MetricsRegistry") -> None:
91
79
  """Include another :class:`MetricsRegistry` into this one.
92
80
 
@@ -95,18 +83,6 @@ class MetricsRegistry:
95
83
 
96
84
  Raises:
97
85
  ValueError: If a metric name already exists in the current registry.
98
-
99
- Example:
100
- ```python
101
- main = MetricsRegistry()
102
- sub = MetricsRegistry()
103
-
104
- @sub.register
105
- def sub_metric():
106
- return Counter("sub_total", "Sub counter")
107
-
108
- main.include_registry(sub)
109
- ```
110
86
  """
111
87
  for metric_name, definition in registry._metrics.items():
112
88
  if metric_name in self._metrics:
@@ -1,12 +1,12 @@
1
1
  """Pytest helper utilities for FastAPI testing."""
2
2
 
3
3
  import os
4
+ import warnings
4
5
  from collections.abc import AsyncGenerator, Callable
5
6
  from contextlib import asynccontextmanager
6
7
  from typing import Any
7
8
 
8
9
  from httpx import ASGITransport, AsyncClient
9
- from sqlalchemy import text
10
10
  from sqlalchemy.engine import make_url
11
11
  from sqlalchemy.ext.asyncio import (
12
12
  AsyncSession,
@@ -15,7 +15,134 @@ from sqlalchemy.ext.asyncio import (
15
15
  )
16
16
  from sqlalchemy.orm import DeclarativeBase
17
17
 
18
- from ..db import create_db_context
18
+ from sqlalchemy import text
19
+
20
+ from ..db import (
21
+ cleanup_tables as _cleanup_tables,
22
+ create_database,
23
+ create_db_context,
24
+ )
25
+
26
+
27
+ async def cleanup_tables(
28
+ session: AsyncSession,
29
+ base: type[DeclarativeBase],
30
+ ) -> None:
31
+ """Truncate all tables for fast between-test cleanup.
32
+
33
+ .. deprecated::
34
+ Import ``cleanup_tables`` from ``fastapi_toolsets.db`` instead.
35
+ This re-export will be removed in v3.0.0.
36
+ """
37
+ warnings.warn(
38
+ "Importing cleanup_tables from fastapi_toolsets.pytest is deprecated "
39
+ "and will be removed in v3.0.0. "
40
+ "Use 'from fastapi_toolsets.db import cleanup_tables' instead.",
41
+ DeprecationWarning,
42
+ stacklevel=2,
43
+ )
44
+ await _cleanup_tables(session=session, base=base)
45
+
46
+
47
+ def _get_xdist_worker(default_test_db: str) -> str:
48
+ """Return the pytest-xdist worker name, or *default_test_db* when not running under xdist.
49
+
50
+ Reads the ``PYTEST_XDIST_WORKER`` environment variable that xdist sets
51
+ automatically in each worker process (e.g. ``"gw0"``, ``"gw1"``).
52
+ When xdist is not installed or not active, the variable is absent and
53
+ *default_test_db* is returned instead.
54
+
55
+ Args:
56
+ default_test_db: Fallback value returned when ``PYTEST_XDIST_WORKER``
57
+ is not set.
58
+ """
59
+ return os.environ.get("PYTEST_XDIST_WORKER", default_test_db)
60
+
61
+
62
+ def worker_database_url(database_url: str, default_test_db: str) -> str:
63
+ """Derive a per-worker database URL for pytest-xdist parallel runs.
64
+
65
+ Appends ``_{worker_name}`` to the database name so each xdist worker
66
+ operates on its own database. When not running under xdist,
67
+ ``_{default_test_db}`` is appended instead.
68
+
69
+ The worker name is read from the ``PYTEST_XDIST_WORKER`` environment
70
+ variable (set automatically by xdist in each worker process).
71
+
72
+ Args:
73
+ database_url: Original database connection URL.
74
+ default_test_db: Suffix appended to the database name when
75
+ ``PYTEST_XDIST_WORKER`` is not set.
76
+
77
+ Returns:
78
+ A database URL with a worker- or default-specific database name.
79
+ """
80
+ worker = _get_xdist_worker(default_test_db=default_test_db)
81
+
82
+ url = make_url(database_url)
83
+ url = url.set(database=f"{url.database}_{worker}")
84
+ return url.render_as_string(hide_password=False)
85
+
86
+
87
+ @asynccontextmanager
88
+ async def create_worker_database(
89
+ database_url: str,
90
+ default_test_db: str = "test_db",
91
+ ) -> AsyncGenerator[str, None]:
92
+ """Create and drop a per-worker database for pytest-xdist isolation.
93
+
94
+ Derives a worker-specific database URL using :func:`worker_database_url`,
95
+ then delegates to :func:`~fastapi_toolsets.db.create_database` to create
96
+ and drop it. Intended for use as a **session-scoped** fixture.
97
+
98
+ When running under xdist the database name is suffixed with the worker
99
+ name (e.g. ``_gw0``). Otherwise it is suffixed with *default_test_db*.
100
+
101
+ Args:
102
+ database_url: Original database connection URL (used as the server
103
+ connection and as the base for the worker database name).
104
+ default_test_db: Suffix appended to the database name when
105
+ ``PYTEST_XDIST_WORKER`` is not set. Defaults to ``"test_db"``.
106
+
107
+ Yields:
108
+ The worker-specific database URL.
109
+
110
+ Example:
111
+ ```python
112
+ from fastapi_toolsets.pytest import create_worker_database, create_db_session
113
+
114
+ DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost/test_db"
115
+
116
+ @pytest.fixture(scope="session")
117
+ async def worker_db_url():
118
+ async with create_worker_database(DATABASE_URL) as url:
119
+ yield url
120
+
121
+ @pytest.fixture
122
+ async def db_session(worker_db_url):
123
+ async with create_db_session(
124
+ worker_db_url, Base, cleanup=True
125
+ ) as session:
126
+ yield session
127
+ ```
128
+ """
129
+ worker_url = worker_database_url(
130
+ database_url=database_url, default_test_db=default_test_db
131
+ )
132
+ worker_db_name: str = make_url(worker_url).database # type: ignore[assignment]
133
+
134
+ engine = create_async_engine(database_url, isolation_level="AUTOCOMMIT")
135
+ try:
136
+ async with engine.connect() as conn:
137
+ await conn.execute(text(f"DROP DATABASE IF EXISTS {worker_db_name}"))
138
+ await create_database(db_name=worker_db_name, server_url=database_url)
139
+
140
+ yield worker_url
141
+
142
+ async with engine.connect() as conn:
143
+ await conn.execute(text(f"DROP DATABASE IF EXISTS {worker_db_name}"))
144
+ finally:
145
+ await engine.dispose()
19
146
 
20
147
 
21
148
  @asynccontextmanager
@@ -156,160 +283,3 @@ async def create_db_session(
156
283
  await conn.run_sync(base.metadata.drop_all)
157
284
  finally:
158
285
  await engine.dispose()
159
-
160
-
161
- def _get_xdist_worker(default_test_db: str) -> str:
162
- """Return the pytest-xdist worker name, or *default_test_db* when not running under xdist.
163
-
164
- Reads the ``PYTEST_XDIST_WORKER`` environment variable that xdist sets
165
- automatically in each worker process (e.g. ``"gw0"``, ``"gw1"``).
166
- When xdist is not installed or not active, the variable is absent and
167
- *default_test_db* is returned instead.
168
-
169
- Args:
170
- default_test_db: Fallback value returned when ``PYTEST_XDIST_WORKER``
171
- is not set.
172
- """
173
- return os.environ.get("PYTEST_XDIST_WORKER", default_test_db)
174
-
175
-
176
- def worker_database_url(database_url: str, default_test_db: str) -> str:
177
- """Derive a per-worker database URL for pytest-xdist parallel runs.
178
-
179
- Appends ``_{worker_name}`` to the database name so each xdist worker
180
- operates on its own database. When not running under xdist,
181
- ``_{default_test_db}`` is appended instead.
182
-
183
- The worker name is read from the ``PYTEST_XDIST_WORKER`` environment
184
- variable (set automatically by xdist in each worker process).
185
-
186
- Args:
187
- database_url: Original database connection URL.
188
- default_test_db: Suffix appended to the database name when
189
- ``PYTEST_XDIST_WORKER`` is not set.
190
-
191
- Returns:
192
- A database URL with a worker- or default-specific database name.
193
-
194
- Example:
195
- ```python
196
- # With PYTEST_XDIST_WORKER="gw0":
197
- url = worker_database_url(
198
- "postgresql+asyncpg://user:pass@localhost/test_db",
199
- default_test_db="test",
200
- )
201
- # "postgresql+asyncpg://user:pass@localhost/test_db_gw0"
202
-
203
- # Without PYTEST_XDIST_WORKER:
204
- url = worker_database_url(
205
- "postgresql+asyncpg://user:pass@localhost/test_db",
206
- default_test_db="test",
207
- )
208
- # "postgresql+asyncpg://user:pass@localhost/test_db_test"
209
- ```
210
- """
211
- worker = _get_xdist_worker(default_test_db=default_test_db)
212
-
213
- url = make_url(database_url)
214
- url = url.set(database=f"{url.database}_{worker}")
215
- return url.render_as_string(hide_password=False)
216
-
217
-
218
- @asynccontextmanager
219
- async def create_worker_database(
220
- database_url: str,
221
- default_test_db: str = "test_db",
222
- ) -> AsyncGenerator[str, None]:
223
- """Create and drop a per-worker database for pytest-xdist isolation.
224
-
225
- Intended for use as a **session-scoped** fixture. Connects to the server
226
- using the original *database_url* (with ``AUTOCOMMIT`` isolation for DDL),
227
- creates a dedicated database for the worker, and yields the worker-specific
228
- URL. On cleanup the worker database is dropped.
229
-
230
- When running under xdist the database name is suffixed with the worker
231
- name (e.g. ``_gw0``). Otherwise it is suffixed with *default_test_db*.
232
-
233
- Args:
234
- database_url: Original database connection URL.
235
- default_test_db: Suffix appended to the database name when
236
- ``PYTEST_XDIST_WORKER`` is not set. Defaults to ``"test_db"``.
237
-
238
- Yields:
239
- The worker-specific database URL.
240
-
241
- Example:
242
- ```python
243
- from fastapi_toolsets.pytest import (
244
- create_worker_database, create_db_session,
245
- )
246
-
247
- DATABASE_URL = "postgresql+asyncpg://postgres:postgres@localhost/test_db"
248
-
249
- @pytest.fixture(scope="session")
250
- async def worker_db_url():
251
- async with create_worker_database(DATABASE_URL) as url:
252
- yield url
253
-
254
- @pytest.fixture
255
- async def db_session(worker_db_url):
256
- async with create_db_session(
257
- worker_db_url, Base, cleanup=True
258
- ) as session:
259
- yield session
260
- ```
261
- """
262
- worker_url = worker_database_url(
263
- database_url=database_url, default_test_db=default_test_db
264
- )
265
- worker_db_name = make_url(worker_url).database
266
-
267
- engine = create_async_engine(
268
- database_url,
269
- isolation_level="AUTOCOMMIT",
270
- )
271
- try:
272
- async with engine.connect() as conn:
273
- await conn.execute(text(f"DROP DATABASE IF EXISTS {worker_db_name}"))
274
- await conn.execute(text(f"CREATE DATABASE {worker_db_name}"))
275
-
276
- yield worker_url
277
-
278
- async with engine.connect() as conn:
279
- await conn.execute(text(f"DROP DATABASE IF EXISTS {worker_db_name}"))
280
- finally:
281
- await engine.dispose()
282
-
283
-
284
- async def cleanup_tables(
285
- session: AsyncSession,
286
- base: type[DeclarativeBase],
287
- ) -> None:
288
- """Truncate all tables for fast between-test cleanup.
289
-
290
- Executes a single ``TRUNCATE … RESTART IDENTITY CASCADE`` statement
291
- across every table in *base*'s metadata, which is significantly faster
292
- than dropping and re-creating tables between tests.
293
-
294
- This is a no-op when the metadata contains no tables.
295
-
296
- Args:
297
- session: An active async database session.
298
- base: SQLAlchemy DeclarativeBase class containing model metadata.
299
-
300
- Example:
301
- ```python
302
- @pytest.fixture
303
- async def db_session(worker_db_url):
304
- async with create_db_session(worker_db_url, Base) as session:
305
- yield session
306
- await cleanup_tables(session, Base)
307
- ```
308
- """
309
- tables = base.metadata.sorted_tables
310
- if not tables:
311
- return
312
-
313
- table_names = ", ".join(f'"{t.name}"' for t in tables)
314
- await session.execute(text(f"TRUNCATE {table_names} RESTART IDENTITY CASCADE"))
315
- await session.commit()