cuneus 0.2.8__tar.gz → 0.2.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. {cuneus-0.2.8 → cuneus-0.2.10}/PKG-INFO +5 -1
  2. cuneus-0.2.10/examples/my_app/main.py +6 -0
  3. cuneus-0.2.10/examples/pyproject.toml +7 -0
  4. {cuneus-0.2.8 → cuneus-0.2.10}/pyproject.toml +9 -5
  5. cuneus-0.2.10/src/cuneus/cli.py +51 -0
  6. {cuneus-0.2.8 → cuneus-0.2.10}/src/cuneus/core/application.py +11 -15
  7. {cuneus-0.2.8 → cuneus-0.2.10}/src/cuneus/core/exceptions.py +2 -6
  8. {cuneus-0.2.8 → cuneus-0.2.10}/src/cuneus/core/extensions.py +7 -0
  9. {cuneus-0.2.8 → cuneus-0.2.10}/src/cuneus/core/logging.py +1 -1
  10. {cuneus-0.2.8 → cuneus-0.2.10}/src/cuneus/core/settings.py +20 -4
  11. cuneus-0.2.10/src/cuneus/dependencies.py +79 -0
  12. cuneus-0.2.10/src/cuneus/ext/database.py +278 -0
  13. cuneus-0.2.10/src/cuneus/ext/health.py +126 -0
  14. cuneus-0.2.10/src/cuneus/ext/otel.py +279 -0
  15. cuneus-0.2.10/src/cuneus/ext/server.py +54 -0
  16. cuneus-0.2.10/src/cuneus/py.typed +0 -0
  17. cuneus-0.2.10/src/cuneus/utils.py +12 -0
  18. cuneus-0.2.10/tests/cli/test_cli.py +141 -0
  19. cuneus-0.2.10/tests/cli/testapp/__init__.py +0 -0
  20. cuneus-0.2.10/tests/cli/testapp/main.py +23 -0
  21. cuneus-0.2.10/tests/cli/testapp/pyproject.toml +9 -0
  22. cuneus-0.2.10/tests/ext/test_database.py +96 -0
  23. cuneus-0.2.10/tests/ext/test_health.py +109 -0
  24. cuneus-0.2.10/tests/ext/test_otel.py +130 -0
  25. cuneus-0.2.10/tests/test_dependencies.py +111 -0
  26. {cuneus-0.2.8 → cuneus-0.2.10}/tests/test_exceptions.py +1 -1
  27. {cuneus-0.2.8 → cuneus-0.2.10}/tests/test_integration.py +18 -8
  28. cuneus-0.2.10/tests/test_utils.py +23 -0
  29. {cuneus-0.2.8 → cuneus-0.2.10}/uv.lock +786 -114
  30. cuneus-0.2.8/src/cuneus/cli.py +0 -143
  31. cuneus-0.2.8/src/cuneus/ext/health.py +0 -129
  32. cuneus-0.2.8/tests/test_cli.py +0 -252
  33. {cuneus-0.2.8 → cuneus-0.2.10}/.gitignore +0 -0
  34. {cuneus-0.2.8 → cuneus-0.2.10}/.python-version +0 -0
  35. {cuneus-0.2.8 → cuneus-0.2.10}/Makefile +0 -0
  36. {cuneus-0.2.8 → cuneus-0.2.10}/README.md +0 -0
  37. {cuneus-0.2.8/src/cuneus/core → cuneus-0.2.10/examples/my_app}/__init__.py +0 -0
  38. {cuneus-0.2.8 → cuneus-0.2.10}/src/cuneus/__init__.py +0 -0
  39. {cuneus-0.2.8/src/cuneus/ext → cuneus-0.2.10/src/cuneus/core}/__init__.py +0 -0
  40. /cuneus-0.2.8/src/cuneus/py.typed → /cuneus-0.2.10/src/cuneus/ext/__init__.py +0 -0
  41. {cuneus-0.2.8 → cuneus-0.2.10}/tests/test_extensions.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cuneus
3
- Version: 0.2.8
3
+ Version: 0.2.10
4
4
  Summary: ASGI application wrapper
5
5
  Project-URL: Homepage, https://github.com/rmyers/cuneus
6
6
  Project-URL: Documentation, https://github.com/rmyers/cuneus#readme
@@ -24,11 +24,15 @@ Requires-Dist: alembic>=1.13.0; extra == 'database'
24
24
  Requires-Dist: asyncpg>=0.29.0; extra == 'database'
25
25
  Requires-Dist: sqlalchemy[asyncio]>=2.0; extra == 'database'
26
26
  Provides-Extra: dev
27
+ Requires-Dist: aiosqlite[dev]>=0.22.1; extra == 'dev'
27
28
  Requires-Dist: alembic>=1.13.0; extra == 'dev'
28
29
  Requires-Dist: asgi-lifespan>=2.1.0; extra == 'dev'
29
30
  Requires-Dist: asyncpg>=0.29.0; extra == 'dev'
30
31
  Requires-Dist: httpx>=0.27; extra == 'dev'
31
32
  Requires-Dist: mypy>=1.8; extra == 'dev'
33
+ Requires-Dist: opentelemetry-api; extra == 'dev'
34
+ Requires-Dist: opentelemetry-instrumentation; extra == 'dev'
35
+ Requires-Dist: opentelemetry-sdk; extra == 'dev'
32
36
  Requires-Dist: pytest-asyncio>=0.23; extra == 'dev'
33
37
  Requires-Dist: pytest-cov>=4.0; extra == 'dev'
34
38
  Requires-Dist: pytest-mock; extra == 'dev'
@@ -0,0 +1,6 @@
1
"""Minimal example application assembled from cuneus defaults."""
from cuneus import build_app

__all__ = ["app", "cli", "lifespan"]

# build_app() returns the FastAPI app, the click CLI group and the svcs lifespan.
app, cli, lifespan = build_app()
@@ -0,0 +1,7 @@
1
+ [project]
2
+ name = "test-project"
3
+ version = "0.0.1"
4
+
5
+ [tool.cuneus]
6
+ app_module = "my_app.main:app"
7
+ cli_module = "my_app.main:cli"
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "cuneus"
3
- version = "0.2.8"
3
+ version = "0.2.10"
4
4
  description = "ASGI application wrapper"
5
5
  readme = "README.md"
6
6
  authors = [{ name = "Robert Myers", email = "robert@julython.org" }]
@@ -18,20 +18,24 @@ dependencies = [
18
18
  ]
19
19
 
20
20
  [project.optional-dependencies]
21
- database = ["sqlalchemy[asyncio]>=2.0", "asyncpg>=0.29.0", "alembic>=1.13.0"]
22
- redis = ["redis>=5.0"]
23
- all = ["cuneus[database,redis]"]
24
21
  dev = [
22
+ "asgi-lifespan>=2.1.0",
23
+ "aiosqlite[dev]>=0.22.1",
25
24
  "cuneus[all]",
25
+ "opentelemetry-sdk",
26
+ "opentelemetry-api",
27
+ "opentelemetry-instrumentation",
26
28
  "pytest>=8.0",
27
29
  "pytest-asyncio>=0.23",
28
30
  "pytest-cov>=4.0",
29
31
  "pytest-mock",
30
32
  "httpx>=0.27",
31
- "asgi-lifespan>=2.1.0",
32
33
  "ruff>=0.3",
33
34
  "mypy>=1.8",
34
35
  ]
36
+ database = ["sqlalchemy[asyncio]>=2.0", "asyncpg>=0.29.0", "alembic>=1.13.0"]
37
+ redis = ["redis>=5.0"]
38
+ all = ["cuneus[database,redis]"]
35
39
 
36
40
  [project.scripts]
37
41
  cuneus = "cuneus.cli:main"
@@ -0,0 +1,51 @@
1
+ """Cuneus CLI entry point."""
2
+
3
+ from typing import Any, cast
4
+
5
+ import click
6
+
7
+ from .core.settings import Settings, ensure_project_in_path
8
+ from .utils import import_from_string
9
+
10
+
11
def get_user_cli(config: Settings | None = None) -> click.Group | None:
    """Load the project's click group named by ``config.cli_module``.

    Args:
        config: Settings providing ``cli_module``. When omitted, a fresh
            ``Settings()`` is loaded at call time. It is not loaded at import
            time, which would run before ``ensure_project_in_path()`` and would
            freeze whatever environment existed when this module was imported.

    Returns:
        The imported group, or ``None`` after printing a warning to stderr
        if the module or attribute cannot be imported.
    """
    if config is None:
        config = Settings()
    try:
        return cast(click.Group, import_from_string(config.cli_module))
    except (ImportError, AttributeError) as e:
        click.echo(f"Warning: Could not load CLI from {config.cli_module}: {e}", err=True)
        return None
18
+
19
+
20
class CuneusCLI(click.Group):
    """Click group that forwards command discovery to the project's own CLI.

    The target group is resolved through ``get_user_cli()`` on first use.
    The outcome, including a failed load (``None``), is cached on the instance.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._user_cli: click.Group | None = None
        self._user_cli_loaded = False

    @property
    def user_cli(self) -> click.Group | None:
        """Delegated group, loaded lazily; ``None`` when it could not be imported."""
        if self._user_cli_loaded:
            return self._user_cli
        self._user_cli = get_user_cli()
        self._user_cli_loaded = True
        return self._user_cli

    def list_commands(self, ctx: click.Context) -> list[str]:
        """Command names exposed by the delegated group (none if unavailable)."""
        target = self.user_cli
        return target.list_commands(ctx) if target else []

    def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None:
        """Look *cmd_name* up in the delegated group (``None`` if unavailable)."""
        target = self.user_cli
        return target.get_command(ctx, cmd_name) if target else None
44
+
45
+
46
# Put the project root on sys.path before any user CLI module is resolved.
ensure_project_in_path()

#: Console-script target (``cuneus = "cuneus.cli:main"`` in pyproject.toml).
main = CuneusCLI(help="Cuneus CLI - FastAPI application framework")

if __name__ == "__main__":
    main()
@@ -19,8 +19,9 @@ from starlette.middleware import Middleware
19
19
  from .settings import Settings
20
20
  from .exceptions import ExceptionExtension
21
21
  from .logging import LoggingExtension
22
- from .extensions import Extension, HasCLI, HasExceptionHandler, HasMiddleware
22
+ from .extensions import Extension, HasCLI, HasExceptionHandler, HasMiddleware, HasRoutes
23
23
  from ..ext.health import HealthExtension
24
+ from ..ext.server import ServerExtension
24
25
 
25
26
  logger = structlog.stdlib.get_logger("cuneus")
26
27
 
@@ -30,6 +31,7 @@ DEFAULTS = (
30
31
  LoggingExtension,
31
32
  HealthExtension,
32
33
  ExceptionExtension,
34
+ ServerExtension,
33
35
  )
34
36
 
35
37
 
@@ -39,9 +41,7 @@ class ExtensionConflictError(Exception):
39
41
  pass
40
42
 
41
43
 
42
- def _instantiate_extension(
43
- ext: ExtensionInput, settings: Settings | None = None
44
- ) -> Extension:
44
+ def _instantiate_extension(ext: ExtensionInput, settings: Settings | None = None) -> Extension:
45
45
  if isinstance(ext, type) or callable(ext):
46
46
  try:
47
47
  return ext(settings=settings)
@@ -56,7 +56,7 @@ def build_app(
56
56
  settings: Settings | None = None,
57
57
  include_defaults: bool = True,
58
58
  **fastapi_kwargs: Any,
59
- ) -> tuple[FastAPI, click.Group]:
59
+ ) -> tuple[FastAPI, click.Group, svcs.fastapi.lifespan]:
60
60
  """
61
61
  Build a FastAPI with extensions preconfigured.
62
62
 
@@ -93,17 +93,9 @@ def build_app(
93
93
 
94
94
  all_extensions = [_instantiate_extension(ext, settings) for ext in all_inputs]
95
95
 
96
- @click.group()
97
- @click.pass_context
98
- def app_cli(ctx: click.Context) -> None:
99
- """Application CLI."""
100
- ctx.ensure_object(dict)
101
-
102
96
  @svcs.fastapi.lifespan
103
97
  @asynccontextmanager
104
- async def lifespan(
105
- app: FastAPI, registry: svcs.Registry
106
- ) -> AsyncIterator[dict[str, Any]]:
98
+ async def lifespan(app: FastAPI, registry: svcs.Registry) -> AsyncIterator[dict[str, Any]]:
107
99
  async with AsyncExitStack() as stack:
108
100
  state: dict[str, Any] = {}
109
101
 
@@ -121,6 +113,7 @@ def build_app(
121
113
 
122
114
  # Parse extensions for middleware and cli commands
123
115
  middleware: list[Middleware] = []
116
+ app_cli = click.Group()
124
117
 
125
118
  for ext in all_extensions:
126
119
  ext_name = ext.__class__.__name__
@@ -139,5 +132,8 @@ def build_app(
139
132
  if isinstance(ext, HasExceptionHandler):
140
133
  logger.debug(f"Loading exception handlers from {ext_name}")
141
134
  ext.add_exception_handler(app)
135
+ if isinstance(ext, HasRoutes):
136
+ logger.debug(f"Loading routes from {ext_name}")
137
+ ext.add_routes(app)
142
138
 
143
- return app, app_cli
139
+ return app, app_cli, lifespan
@@ -160,9 +160,7 @@ class ExceptionExtension(BaseExtension):
160
160
  app.add_exception_handler(AppException, self._handle_app_exception) # type: ignore[arg-type]
161
161
  app.add_exception_handler(Exception, self._handle_unexpected_exception)
162
162
 
163
- def _handle_app_exception(
164
- self, request: Request, exc: AppException
165
- ) -> JSONResponse:
163
+ def _handle_app_exception(self, request: Request, exc: AppException) -> JSONResponse:
166
164
  if exc.status_code >= 500 and self.settings.log_server_errors:
167
165
  log.exception("server_error", error_code=exc.error_code)
168
166
  else:
@@ -180,9 +178,7 @@ class ExceptionExtension(BaseExtension):
180
178
  headers=headers,
181
179
  )
182
180
 
183
- def _handle_unexpected_exception(
184
- self, request: Request, exc: Exception
185
- ) -> JSONResponse:
181
+ def _handle_unexpected_exception(self, request: Request, exc: Exception) -> JSONResponse:
186
182
  log.exception("unexpected_error", exc_info=exc)
187
183
  response: dict[str, Any] = {
188
184
  "error": {
@@ -67,6 +67,13 @@ class HasExceptionHandler(Protocol):
67
67
  def add_exception_handler(self, app: FastAPI) -> None: ... # pragma: no cover
68
68
 
69
69
 
70
@runtime_checkable
class HasRoutes(Protocol):
    """Structural type for extensions that contribute HTTP routes.

    ``build_app`` tests each extension with ``isinstance(ext, HasRoutes)``.
    On a match it calls ``ext.add_routes(app)`` while assembling the app.
    """

    def add_routes(self, app: FastAPI) -> None: ...  # pragma: no cover
75
+
76
+
70
77
  class BaseExtension:
71
78
  """
72
79
  Base class for extensions with explicit startup/shutdown hooks.
@@ -43,7 +43,7 @@ def configure_structlog(settings: Settings | None = None) -> None:
43
43
  term_width = shutil.get_terminal_size().columns
44
44
  pad_event = term_width - 36
45
45
  renderer: structlog.types.Processor = structlog.dev.ConsoleRenderer(
46
- colors=True, pad_event=pad_event
46
+ colors=True, pad_event_to=pad_event
47
47
  )
48
48
 
49
49
  # Configure structlog
@@ -1,6 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
- import logging
3
+ import pathlib
4
+ import sys
4
5
  from pydantic_settings import (
5
6
  BaseSettings,
6
7
  PydanticBaseSettingsSource,
@@ -8,8 +9,6 @@ from pydantic_settings import (
8
9
  SettingsConfigDict,
9
10
  )
10
11
 
11
- logger = logging.getLogger(__name__)
12
-
13
12
  DEFAULT_TOOL_NAME = "cuneus"
14
13
 
15
14
 
@@ -40,7 +39,6 @@ class CuneusBaseSettings(BaseSettings):
40
39
 
41
40
 
42
41
  class Settings(CuneusBaseSettings):
43
-
44
42
  model_config = SettingsConfigDict(
45
43
  env_file=".env",
46
44
  env_file_encoding="utf-8",
@@ -64,3 +62,21 @@ class Settings(CuneusBaseSettings):
64
62
  # health
65
63
  health_enabled: bool = True
66
64
  health_prefix: str = "/healthz"
65
+
66
+ @classmethod
67
+ def get_project_root(cls) -> pathlib.Path:
68
+ """
69
+ Get the project root by inspecting where pydantic-settings
70
+ found the pyproject.toml file.
71
+ """
72
+ source = PyprojectTomlConfigSettingsSource(
73
+ cls,
74
+ )
75
+ return source.toml_file_path.parent
76
+
77
+
78
def ensure_project_in_path() -> None:
    """Prepend the project root to ``sys.path`` unless it is already listed.

    This makes modules inside the project importable by dotted name, for
    example the ``cli_module`` setting that cuneus.cli loads.
    """
    root = str(Settings.get_project_root())
    if root not in sys.path:  # pragma: no branch
        sys.path.insert(0, root)
@@ -0,0 +1,79 @@
1
+ # cuneus/dependencies.py
2
+ from __future__ import annotations
3
+
4
+ import importlib
5
+ import logging
6
+ from dataclasses import dataclass
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
@dataclass
class Dependency:
    """A module an extension needs, plus the package name used in install hints.

    ``import_name`` is what gets passed to ``importlib.import_module``.
    ``package_name`` is only needed when the installable package has a
    different name (for example, import ``yaml`` from package ``pyyaml``).
    """

    import_name: str
    package_name: str | None = None  # pip package name if different from import

    @property
    def pip_name(self) -> str:
        """Name shown in install hints: ``package_name`` if set, else ``import_name``."""
        if self.package_name:
            return self.package_name
        return self.import_name
21
+
22
+
23
class MissingDependencyError(ImportError):
    """Raised when required dependencies are not installed.

    Attributes:
        extension: Name of the extension that requested the modules.
        missing: The dependency entries that could not be imported.
    """

    def __init__(self, extension: str, missing: list[Dependency]):
        self.extension = extension
        self.missing = missing
        install_hint = " ".join(dep.pip_name for dep in missing)
        message = (
            f"{extension} requires additional dependencies. "
            f"Install with: uv add {install_hint}"
        )
        super().__init__(message)
33
+
34
+
35
def check_dependencies(extension: str, *deps: Dependency) -> None:
    """
    Check that dependencies are installed, raise helpful error if not.

    Every dependency is tried, rather than stopping at the first failure, so a
    single error lists all missing packages at once.

    Usage:
        from cuneus.dependencies import check_dependencies, Dependency

        check_dependencies(
            "DatabaseExtension",
            Dependency("sqlalchemy"),
            Dependency("asyncpg"),
        )

    Args:
        extension: Name used in the error message.
        *deps: Modules that must be importable.

    Raises:
        MissingDependencyError: if any of *deps* fails to import.
    """
    missing: list[Dependency] = []
    for dep in deps:
        try:
            importlib.import_module(dep.import_name)
        except ImportError:
            missing.append(dep)

    if missing:
        raise MissingDependencyError(extension, missing)
58
+
59
def warn_missing(extension: str, *deps: Dependency) -> list[Dependency]:
    """Log one warning for optional dependencies that cannot be imported.

    Unlike ``check_dependencies`` this never raises. It is meant for optional
    features inside an extension.

    Returns:
        The dependencies that failed to import, in the order given. The list
        is empty when everything is present.
    """

    def _importable(dep: Dependency) -> bool:
        try:
            importlib.import_module(dep.import_name)
        except ImportError:
            return False
        return True

    missing = [dep for dep in deps if not _importable(dep)]
    if missing:
        install_hint = " ".join(dep.pip_name for dep in missing)
        logger.warning(
            f"{extension}: optional dependencies not installed. "
            f"Some features disabled. Install with: uv add {install_hint}"
        )
    return missing
@@ -0,0 +1,278 @@
1
+ # cuneus/ext/database.py
2
+ from __future__ import annotations
3
+
4
+ import logging
5
+ from contextlib import asynccontextmanager
6
+ from pathlib import Path
7
+ from typing import Any, AsyncIterator
8
+
9
+ import click
10
+ import svcs
11
+ from fastapi import FastAPI
12
+ from pydantic import Field, SecretStr, computed_field
13
+ from pydantic_settings import SettingsConfigDict
14
+ from structlog.stdlib import get_logger
15
+
16
+ from ..core.extensions import BaseExtension, HasCLI
17
+ from ..core.settings import CuneusBaseSettings, DEFAULT_TOOL_NAME
18
+ from ..dependencies import Dependency, check_dependencies
19
+
20
+ check_dependencies(
21
+ "cuneus.ext.database",
22
+ Dependency("sqlalchemy"),
23
+ )
24
+
25
+ from sqlalchemy import URL, make_url, text
26
+ from sqlalchemy.ext.asyncio import (
27
+ AsyncEngine,
28
+ AsyncSession,
29
+ async_sessionmaker,
30
+ create_async_engine,
31
+ )
32
+
33
+ logger = get_logger(__name__)
34
+
35
+
36
class DatabaseSettings(CuneusBaseSettings):
    """Database configuration.

    Either set ``url`` to a full connection string, or leave it empty and let
    ``url_parsed`` assemble one from ``driver``/``host``/``port``/``name`` and
    the optional credentials. The model is configured with the ``DATABASE_``
    env prefix, a ``.env`` file and the ``[tool.cuneus.database]``
    pyproject table.
    """

    model_config = SettingsConfigDict(
        env_prefix="DATABASE_",
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
        pyproject_toml_depth=2,
        pyproject_toml_table_header=("tool", DEFAULT_TOOL_NAME, "database"),
    )

    # Full connection string; when non-empty it overrides every part below.
    url: str | None = None

    # Connection parts, only consulted when ``url`` is empty.
    driver: str = "postgresql+asyncpg"
    host: str = "localhost"
    port: int = 5432
    name: str = "app"
    username: str | None = None
    password: SecretStr | None = None

    # Engine / pool tuning.
    pool_size: int = 5
    pool_max_overflow: int = 10
    pool_recycle: int = 3600
    echo: bool = False

    # alembic.ini used by the ``db`` CLI commands.
    alembic_config: Path = Path("alembic.ini")

    # NOTE(review): as a computed_field, url_parsed is likely included in
    # model_dump(), which may expose the raw password - confirm before dumping.
    @computed_field
    @property
    def url_parsed(self) -> URL:
        """SQLAlchemy URL from ``url`` if set, otherwise built from the parts."""
        if self.url:
            return make_url(self.url)

        if "sqlite" in self.driver:
            # SQLite needs only the driver and database path; host, port and
            # credentials are deliberately dropped.
            return URL.create(drivername=self.driver, database=self.name)

        secret = self.password.get_secret_value() if self.password else None
        return URL.create(
            drivername=self.driver,
            username=self.username,
            password=secret,
            host=self.host,
            port=self.port,
            database=self.name,
        )

    @computed_field
    @property
    def url_redacted(self) -> str:
        """Connection URL rendered with the password masked, safe for logs."""
        parsed = self.url_parsed
        return parsed.render_as_string(hide_password=True)
93
+
94
+
95
class DatabaseExtension(BaseExtension, HasCLI):
    """
    Database extension providing AsyncSession via svcs.

    Registers:
        - AsyncEngine: The SQLAlchemy async engine (with a ``SELECT 1`` ping)
        - async_sessionmaker: Factory for creating sessions
        - AsyncSession: Session from the factory; committed when the consumer
          finishes cleanly, rolled back if an exception propagates

    CLI Commands:
        - db upgrade [revision]: Run migrations
        - db downgrade [revision]: Rollback migrations
        - db revision -m "message": Create new migration
        - db current: Show current revision
        - db history: Show migration history
        - db init [template]: Create an alembic setup (default template: async)
        - db check: Check database connectivity

    Configuration (env or pyproject.toml [tool.cuneus.database]):
        DATABASE_URL: Connection string
        DATABASE_POOL_SIZE: Connection pool size (default: 5)
        DATABASE_POOL_MAX_OVERFLOW: Max overflow connections (default: 10)
        DATABASE_POOL_RECYCLE: Connection recycle time in seconds (default: 3600)
        DATABASE_ECHO: Echo SQL statements (default: false)
        DATABASE_ALEMBIC_CONFIG: Path to alembic.ini (default: alembic.ini)
    """

    _session_factory: async_sessionmaker[AsyncSession]
    _engine: AsyncEngine

    def __init__(self, settings: DatabaseSettings | None = None):
        # core/application.py's _instantiate_extension calls
        # ``ext(settings=settings)`` with the *core* Settings object. That object
        # has none of the database fields, so only a DatabaseSettings is kept;
        # anything else falls back to loading DatabaseSettings from config.
        if not isinstance(settings, DatabaseSettings):
            settings = DatabaseSettings()
        self.settings = settings

    def _create_engine(self) -> AsyncEngine:
        """Build the AsyncEngine from settings, honouring the pool options."""
        url = self.settings.url_parsed
        options: dict[str, Any] = {
            "pool_recycle": self.settings.pool_recycle,
            "echo": self.settings.echo,
        }
        # Sizing options only apply to queue-style pools. SQLite dialects may
        # select a pool class that rejects them, so they are skipped there.
        if url.get_backend_name() != "sqlite":
            options["pool_size"] = self.settings.pool_size
            options["max_overflow"] = self.settings.pool_max_overflow
        return create_async_engine(url, **options)

    @asynccontextmanager
    async def register(
        self, registry: svcs.Registry, app: FastAPI
    ) -> AsyncIterator[dict[str, Any]]:
        """Create the engine and session factory and register them with svcs.

        Yields:
            Lifespan state with ``db_engine`` and ``db_session_factory``.
            The engine is disposed when the context exits.
        """
        self._engine = self._create_engine()

        self._session_factory = async_sessionmaker(
            self._engine,
            class_=AsyncSession,
            expire_on_commit=False,
        )

        registry.register_value(AsyncEngine, self._engine, ping=self._check)
        registry.register_value(async_sessionmaker, self._session_factory)

        @asynccontextmanager
        async def session_factory() -> AsyncIterator[AsyncSession]:
            async with self._session_factory() as session:
                try:
                    yield session
                    await session.commit()
                except Exception:
                    await session.rollback()
                    raise

        registry.register_factory(AsyncSession, session_factory)

        logger.info("Database started", url=self.settings.url_redacted)

        try:
            yield {
                "db_engine": self._engine,
                "db_session_factory": self._session_factory,
            }
        finally:
            await self._engine.dispose()
            logger.info("Database shutdown")

    async def _check(self, engine: AsyncEngine | None = None) -> None:
        """Health ping: run ``SELECT 1`` through the application's engine.

        svcs hands the registered service (the AsyncEngine) to ping callbacks.
        The argument is optional so the method also works when called directly.
        Reusing the pooled engine avoids building and disposing a throwaway
        engine on every health check.
        """
        target = engine if engine is not None else self._engine
        async with target.connect() as conn:
            await conn.execute(text("SELECT 1"))

    def register_cli(self, cli_group: click.Group) -> None:
        """Attach the ``db`` command group to *cli_group*."""
        settings = self.settings

        @cli_group.group()
        def db() -> None:
            """Database management commands."""

        @db.command()
        @click.argument("revision", default="head")
        def upgrade(revision: str) -> None:
            """Upgrade database to revision (default: head)."""
            _run_alembic_cmd("upgrade", settings.alembic_config, revision=revision)

        @db.command()
        @click.argument("revision", default="-1")
        def downgrade(revision: str) -> None:
            """Downgrade database to revision (default: -1)."""
            _run_alembic_cmd("downgrade", settings.alembic_config, revision=revision)

        @db.command()
        @click.option("-m", "--message", required=True, help="Migration message")
        @click.option("--autogenerate/--no-autogenerate", default=True)
        def revision(message: str, autogenerate: bool) -> None:
            """Create a new migration revision."""
            _run_alembic_cmd(
                "revision",
                settings.alembic_config,
                message=message,
                autogenerate=autogenerate,
            )

        @db.command()
        def current() -> None:
            """Show current database revision."""
            _run_alembic_cmd("current", settings.alembic_config)

        @db.command()
        def history() -> None:
            """Show migration history."""
            _run_alembic_cmd("history", settings.alembic_config)

        @db.command()
        @click.argument("template", default="async")
        @click.option(
            "--directory",
            default="alembic",
            show_default=True,
            help="Directory to create the migration environment in.",
        )
        def init(template: str, directory: str) -> None:
            """Create a new alembic setup (uses the async template by default)."""
            from alembic import command
            from alembic.config import Config
            from alembic.util import CommandError

            # alembic writes the generated ini to the Config's file name, so
            # this targets the configured DATABASE_ALEMBIC_CONFIG path.
            cfg = Config(str(settings.alembic_config))
            try:
                command.init(cfg, directory, template=template)
            except CommandError as exc:
                raise click.ClickException(str(exc)) from exc

        @db.command()
        @click.pass_context
        def check(ctx: click.Context) -> None:
            """Check database connectivity."""
            import asyncio

            async def _probe() -> Exception | None:
                # Returns the connection error (if any) so the exit happens
                # outside the event loop, after the engine is disposed.
                engine = create_async_engine(settings.url_parsed)
                try:
                    async with engine.connect() as conn:
                        await conn.execute(text("SELECT 1"))
                except Exception as exc:  # any failure is reported to the user
                    return exc
                finally:
                    await engine.dispose()
                return None

            error = asyncio.run(_probe())
            if error is None:
                click.echo("✓ Database connection OK")
            else:
                click.echo(f"✗ Database connection failed: {error}", err=True)
                ctx.exit(1)
248
+
249
+
250
def _run_alembic_cmd(
    cmd: str,
    config_path: Path,
    revision: str | None = None,
    message: str | None = None,
    autogenerate: bool = False,
) -> None:
    """Dispatch *cmd* to the matching ``alembic.command`` function.

    Supported commands: upgrade, downgrade, revision, current, history.

    Raises:
        click.ClickException: if *config_path* does not exist or *cmd* is
            not one of the supported commands.
    """
    from alembic import command
    from alembic.config import Config

    if not config_path.exists():
        raise click.ClickException(f"Alembic config not found: {config_path}")

    cfg = Config(str(config_path))

    handlers = {
        "upgrade": lambda: command.upgrade(cfg, revision or "head"),
        "downgrade": lambda: command.downgrade(cfg, revision or "-1"),
        "revision": lambda: command.revision(cfg, message=message, autogenerate=autogenerate),
        "current": lambda: command.current(cfg),
        "history": lambda: command.history(cfg),
    }
    handler = handlers.get(cmd)
    if handler is None:
        raise click.ClickException(f"Unknown command: {cmd}")
    handler()