contentgrid-extension-helpers 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. contentgrid_extension_helpers/__init__.py +3 -0
  2. contentgrid_extension_helpers/authentication/__init__.py +2 -2
  3. contentgrid_extension_helpers/authentication/user.py +28 -2
  4. contentgrid_extension_helpers/config.py +40 -0
  5. contentgrid_extension_helpers/dependencies/authentication/user.py +73 -0
  6. contentgrid_extension_helpers/dependencies/clients/contentgrid/__init__.py +3 -0
  7. contentgrid_extension_helpers/dependencies/clients/contentgrid/client_factory.py +71 -0
  8. contentgrid_extension_helpers/dependencies/clients/contentgrid/extension_flow_factory.py +85 -0
  9. contentgrid_extension_helpers/dependencies/clients/contentgrid/service_account_factory.py +87 -0
  10. contentgrid_extension_helpers/dependencies/sqlalch/__init__.py +0 -0
  11. contentgrid_extension_helpers/dependencies/sqlalch/db/__init__.py +14 -0
  12. contentgrid_extension_helpers/dependencies/sqlalch/db/base_factory.py +107 -0
  13. contentgrid_extension_helpers/dependencies/sqlalch/db/postgres.py +104 -0
  14. contentgrid_extension_helpers/dependencies/sqlalch/db/sqlite.py +43 -0
  15. contentgrid_extension_helpers/dependencies/sqlalch/repositories/__init__.py +1 -0
  16. contentgrid_extension_helpers/dependencies/sqlalch/repositories/base_repository.py +52 -0
  17. contentgrid_extension_helpers/exceptions.py +48 -0
  18. contentgrid_extension_helpers/logging/__init__.py +1 -1
  19. contentgrid_extension_helpers/logging/json_logging.py +2 -2
  20. contentgrid_extension_helpers/middleware/exception_middleware.py +124 -0
  21. contentgrid_extension_helpers/problem_response.py +45 -0
  22. contentgrid_extension_helpers/responses/__init__.py +0 -0
  23. contentgrid_extension_helpers/responses/hal.py +212 -0
  24. contentgrid_extension_helpers/structured_output/model_deny.py +45 -0
  25. {contentgrid_extension_helpers-0.0.1.dist-info → contentgrid_extension_helpers-0.0.3.dist-info}/METADATA +8 -2
  26. contentgrid_extension_helpers-0.0.3.dist-info/RECORD +30 -0
  27. {contentgrid_extension_helpers-0.0.1.dist-info → contentgrid_extension_helpers-0.0.3.dist-info}/WHEEL +1 -1
  28. contentgrid_extension_helpers-0.0.1.dist-info/RECORD +0 -11
  29. {contentgrid_extension_helpers-0.0.1.dist-info → contentgrid_extension_helpers-0.0.3.dist-info/licenses}/LICENSE +0 -0
  30. {contentgrid_extension_helpers-0.0.1.dist-info → contentgrid_extension_helpers-0.0.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,3 @@
1
+ from .exceptions import LLMDenyException, ExtensionHelperException, IllegalActivityError, MissingInputError, InjectionError, MalformedInputError, NotRelatedError, SensitiveInformationError, SecurityError # noqa: F401
2
+ from .middleware.exception_middleware import catch_exceptions_middleware # noqa: F401
3
+ from .structured_output.model_deny import DenyReason, ModelDeny # noqa: F401
@@ -1,2 +1,2 @@
1
- from .oidc import create_current_user_dependency, decode_and_verify_token, create_oauth2_scheme, get_jwks_client
2
- from .user import ContentGridUser
1
+ from .oidc import create_current_user_dependency, decode_and_verify_token, create_oauth2_scheme, get_oauth_jwks_client, get_oidc_jwks_client # noqa: F401
2
+ from .user import ContentGridUser # noqa: F401
@@ -1,4 +1,5 @@
1
- from pydantic import BaseModel
1
+ from typing import List
2
+ from pydantic import BaseModel, Field
2
3
 
3
4
  class ContentGridUser(BaseModel):
4
5
  sub: str
@@ -6,4 +7,29 @@ class ContentGridUser(BaseModel):
6
7
  exp: float
7
8
  name: str | None = None
8
9
  email: str | None = None
9
- access_token: str
10
+ access_token: str
11
+ domains : List[str] = Field(validation_alias="context:application:domains")
12
+ application_id : str = Field(validation_alias="context:application:id")
13
+
14
+
15
+
16
+ if __name__ == "__main__":
17
+ user = ContentGridUser(
18
+ **{
19
+ "sub": "https://auth.sandbox.contentgrid.cloud/realms/cg-77594d8b-9bc9-40ed-b2a8-9c03a2905a20#d91e9a4d-8447-4bf1-8d10-339b2e5951ea",
20
+ "aud": "contentgrid:extension:extract",
21
+ "restrict:principal_claims": "24ZPEGV0IS5MAF8C2BjmaqH1p7wL4YS409zlL8ZE+nEUHsFFDu80eDpJXoFvZIb1Hh9bxamGaK0gE14wvA+btCuDrg5lkGcdCVj3zm/RWnIFKzlGUVn7Zkj4z4PCzsq/itKVNXEYBtAS/d0NRFSiZGvy775kFdK1VOi+hxsic1bHAZTvSs1jEFuddxEULExh2MqZ5h43n/vEhB0sxkmXevR7XSE4iolDzCWGrw6HzUZYP/QlSlz/S3cK+aeoShAP1G2SbuTGub5h1fsKMM22eg==",
22
+ "iss": "https://extensions.sandbox.contentgrid.cloud/authentication/external",
23
+ "may_act": {
24
+ "sub": "extract",
25
+ "iss": "https://auth.sandbox.contentgrid.cloud/realms/extensions"
26
+ },
27
+ "context:application:domains": [
28
+ "8be240cc-4581-43c2-96db-a8ccf8579e7d.sandbox.contentgrid.cloud"
29
+ ],
30
+ "exp": 1755775732,
31
+ "context:application:id": "8be240cc-4581-43c2-96db-a8ccf8579e7d",
32
+ "access_token" : "123"
33
+ }
34
+ )
35
+ print(user)
@@ -0,0 +1,40 @@
1
+
2
+ from pydantic_settings import BaseSettings, SettingsConfigDict
3
+ from typing import List, Optional
4
+
5
+ # Pydantic settings is used throughout this file.
6
+ # The library allows for easy configuration management, including environment variable loading and validation.
7
+ # Each field can be configured to load from environment variables, and validation can be applied to ensure correct types and formats.
8
+ # Example:
9
+ # class MyConfig(BaseSettings):
10
+ # my_field: str = "default_value"
11
+ # my_required_bool: bool # This field must be provided
12
+ # my_optional_field: Optional[int] = None
13
+ #
14
+ # ENVIRONMENT VARIABLES:
15
+ # MY_FIELD=my_value
16
+ # MY_REQUIRED_BOOL=t
17
+ # MY_OPTIONAL_FIELD=42
18
+ # See https://docs.pydantic.dev/latest/api/pydantic_settings/ for more details.
19
+
20
+ class ExtensionConfig(BaseSettings):
21
+ model_config = SettingsConfigDict(
22
+ env_file=['.env', '.env.secret'],
23
+ env_file_encoding='utf-8',
24
+ extra='ignore',
25
+ )
26
+
27
+ cors_origins : List[str] = [] # is taken into account if production is True
28
+
29
+ # Server Configuration
30
+ server_url: Optional[str] = None # Base URL for the server, can be set to None for local development
31
+ server_host: Optional[str] = ""
32
+ server_port: Optional[int] = None
33
+ web_concurrency: Optional[int] = None
34
+
35
+ # Environment Configuration
36
+ ci: bool = False
37
+ production: bool = False
38
+
39
+
40
+ extension_config = ExtensionConfig()
@@ -0,0 +1,73 @@
1
+ from typing import Optional, cast, TypeVar, Generic
2
+ from fastapi import Depends
3
+ from typing_extensions import Annotated
4
+ from pydantic import BaseModel
5
+ from pydantic_settings import BaseSettings
6
+ from contentgrid_extension_helpers.authentication.user import ContentGridUser
7
+ from contentgrid_extension_helpers.authentication.oidc import create_current_user_dependency, get_oauth_jwks_client, create_oauth2_scheme
8
+
9
+ oauth2_scheme = create_oauth2_scheme()
10
+
11
+ UserModelType = TypeVar('UserModelType', bound=BaseModel)
12
+
13
+ class ContentGridUserConfig(BaseSettings):
14
+ extension_name : str # Should be the same extension name as defined in tokenmonger and keycloak (without the contentgrid:extension: prefix)
15
+ oauth_issuer: str
16
+ extension_auth_url: str
17
+
18
+ class ContentGridUserDependency(Generic[UserModelType]):
19
+ def __init__(
20
+ self,
21
+ extension_name: Optional[str] = None,
22
+ oauth_issuer: Optional[str] = None,
23
+ custom_audience: Optional[str] = None,
24
+ user_model: type[UserModelType] = ContentGridUser,
25
+ algorithms: Optional[list[str]] = None,
26
+ verify_exp: bool = True,
27
+ verify_aud: bool = True,
28
+ verify_iss: bool = True,
29
+ verify_nbf: bool = False,
30
+ verify_iat: bool = False,
31
+ ) -> None:
32
+
33
+ self.user_model = user_model
34
+
35
+ # Create config dict with provided parameters
36
+ config_dict = {}
37
+ if extension_name is not None:
38
+ config_dict['extension_name'] = extension_name
39
+ if oauth_issuer is not None:
40
+ config_dict['oauth_issuer'] = oauth_issuer
41
+
42
+ # Create ContentGridUserConfig instance which will use env vars for missing values
43
+ self.user_config = ContentGridUserConfig(**config_dict)
44
+
45
+ if not custom_audience:
46
+ self.audience = f"contentgrid:extension:{self.user_config.extension_name}"
47
+ else:
48
+ self.audience = custom_audience
49
+
50
+ _ , self.jwks_client = get_oauth_jwks_client(self.user_config.oauth_issuer)
51
+
52
+ self.user_dependency = create_current_user_dependency(
53
+ jwks_client=self.jwks_client,
54
+ oidc_issuer=self.user_config.oauth_issuer,
55
+ audience=self.audience,
56
+ user_model=user_model,
57
+ algorithms = algorithms,
58
+ verify_exp = verify_exp,
59
+ verify_aud = verify_aud,
60
+ verify_iss = verify_iss,
61
+ verify_nbf = verify_nbf,
62
+ verify_iat = verify_iat,
63
+ )
64
+
65
+ @property
66
+ def config(self) -> ContentGridUserConfig:
67
+ return self.user_config
68
+
69
+ async def __call__(
70
+ self, token: Annotated[str, Depends(oauth2_scheme)]
71
+ ) -> UserModelType:
72
+ user = cast(UserModelType, await self.user_dependency(token))
73
+ return user
@@ -0,0 +1,3 @@
1
+ from .client_factory import ContentGridBaseClientFactory # noqa: F401
2
+ from .service_account_factory import ContentGridServiceAccountFactory # noqa: F401
3
+ from .extension_flow_factory import ContentGridExtensionFlowClientFactory # noqa: F401
@@ -0,0 +1,71 @@
1
+ from abc import ABC
2
+ from typing import Optional, TypeVar, List
3
+ from contentgrid_hal_client.hal import HALFormsClient
4
+ from fastapi import HTTPException, status
5
+ from pydantic import HttpUrl
6
+ from pydantic_settings import BaseSettings
7
+ import logging
8
+
9
+
10
+ T = TypeVar('T', bound=HALFormsClient)
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ class ContentGridClientFactorySettings(BaseSettings):
15
+ pass
16
+
17
+ class ContentGridBaseClientFactory(ABC):
18
+ def __init__(self) -> None:
19
+ self.env_config = ContentGridClientFactorySettings()
20
+
21
+ def _get_client_endpoint(
22
+ self,
23
+ origin: Optional[HttpUrl] = None,
24
+ allowed_domains: Optional[List[str]] = None,
25
+ default_endpoint: Optional[str] = None
26
+ ) -> str:
27
+ """
28
+ Get client endpoint with domain validation.
29
+
30
+ Args:
31
+ origin: Origin URL to validate and use
32
+ allowed_domains: List of allowed domain strings for validation
33
+ default_endpoint: Default endpoint to use if no origin provided
34
+ (Note: there is no 'user' parameter; callers that have a ContentGrid user pass its domains via allowed_domains.)
35
+
36
+ Returns:
37
+ Validated client endpoint URL
38
+
39
+ Raises:
40
+ HTTPException: 403 if the origin host is not in the allowed domains. ValueError: If no origin, default endpoint, or allowed domain is available.
41
+ """
42
+ if allowed_domains is None:
43
+ # No domains available, will rely on default_endpoint or production check
44
+ allowed_domains = []
45
+
46
+ if origin:
47
+ origin_host = origin.host
48
+ origin_scheme = origin.scheme
49
+ endpoint = f"{origin_scheme}://{origin_host}"
50
+
51
+ # Validate domain if we have allowed domains defined (even if empty)
52
+ if origin_host not in allowed_domains:
53
+ error_msg = f"Origin domain '{origin_host}' not in allowed domains: {allowed_domains}"
54
+ raise HTTPException(
55
+ status_code=status.HTTP_403_FORBIDDEN,
56
+ detail=error_msg,
57
+ )
58
+
59
+ return endpoint
60
+
61
+ elif default_endpoint:
62
+ # Use default endpoint
63
+ return default_endpoint
64
+
65
+ elif allowed_domains:
66
+ # Fallback to the first allowed domain
67
+ return "https://" + allowed_domains[0]
68
+
69
+ else:
70
+ error_msg = "No endpoint available: provide either 'origin', 'default_endpoint', or ensure 'allowed_domains'/'user.domains' are available"
71
+ raise ValueError(error_msg)
@@ -0,0 +1,85 @@
1
+ from typing import Annotated, Optional, Type, TypeVar
2
+
3
+ from contentgrid_extension_helpers.authentication.user import ContentGridUser
4
+ from fastapi import Depends, Query
5
+ from contentgrid_extension_helpers.authentication.oidc import (
6
+ create_oauth2_scheme,
7
+ )
8
+ from contentgrid_application_client.application import ContentGridApplicationClient
9
+ from contentgrid_hal_client.hal import HALFormsClient
10
+ from contentgrid_hal_client.security import IdentityAuthenticationManager
11
+ from pydantic import HttpUrl
12
+ from contentgrid_extension_helpers.dependencies.authentication.user import ContentGridUserDependency
13
+ from .client_factory import ContentGridBaseClientFactory, ContentGridClientFactorySettings
14
+
15
+ T = TypeVar('T', ContentGridApplicationClient, HALFormsClient)
16
+
17
+ oauth2_scheme = create_oauth2_scheme()
18
+
19
+ class ExtensionFlowConfig(ContentGridClientFactorySettings):
20
+ extension_client_name: str
21
+ extension_client_secret: str
22
+ system_exchange_uri: str = "https://extensions.sandbox.contentgrid.cloud/authentication/system/token"
23
+ extension_auth_url: str = "https://auth.sandbox.contentgrid.cloud/realms/extensions/protocol/openid-connect/token"
24
+ delegated_exchange_uri: str = "https://extensions.sandbox.contentgrid.cloud/authentication/delegated/token"
25
+
26
+ class ContentGridExtensionFlowClientFactory(ContentGridBaseClientFactory):
27
+ def __init__(
28
+ self,
29
+ extension_auth_url: Optional[str] = None,
30
+ extension_client_name: Optional[str] = None,
31
+ extension_client_secret: Optional[str] = None,
32
+ system_exchange_uri: Optional[str] = None,
33
+ delegated_exchange_uri: Optional[str] = None,
34
+ ) -> None:
35
+ # Create config dict with provided parameters
36
+ config_dict = {}
37
+ if extension_auth_url is not None:
38
+ config_dict['extension_auth_url'] = extension_auth_url
39
+ if extension_client_name is not None:
40
+ config_dict['extension_client_name'] = extension_client_name
41
+ if extension_client_secret is not None:
42
+ config_dict['extension_client_secret'] = extension_client_secret
43
+ if system_exchange_uri is not None:
44
+ config_dict['system_exchange_uri'] = system_exchange_uri
45
+ if delegated_exchange_uri is not None:
46
+ config_dict['delegated_exchange_uri'] = delegated_exchange_uri
47
+
48
+ # Create ExtensionFlowConfig instance which will use env vars for missing values
49
+ self.extension_config = ExtensionFlowConfig(**config_dict)
50
+
51
+ self.identity_auth_manager = IdentityAuthenticationManager(
52
+ auth_uri=self.extension_config.extension_auth_url,
53
+ client_id=self.extension_config.extension_client_name,
54
+ client_secret=self.extension_config.extension_client_secret,
55
+ system_exchange_uri=self.extension_config.system_exchange_uri,
56
+ delegated_exchange_uri=self.extension_config.delegated_exchange_uri,
57
+ )
58
+ super().__init__()
59
+
60
+ @property
61
+ def config(self) -> ExtensionFlowConfig:
62
+ return self.extension_config
63
+
64
+ def get_client(self, user: ContentGridUser, origin: Optional[HttpUrl], client_type: Type[T] = ContentGridApplicationClient) -> T:
65
+ """Get a client of the specified type."""
66
+ client_endpoint = self._get_client_endpoint(
67
+ origin=origin,
68
+ allowed_domains=user.domains
69
+ )
70
+ auth_manager = self.identity_auth_manager.for_user(user.access_token, urls={client_endpoint})
71
+
72
+ return client_type(
73
+ client_endpoint=client_endpoint,
74
+ auth_manager=auth_manager,
75
+ )
76
+
77
+ def create_client_dependency(self, user_dependency: ContentGridUserDependency, client_type: Type[T] = ContentGridApplicationClient):
78
+ """Create a dependency function for the specified client type."""
79
+ def client_dependency(
80
+ user: Annotated[ContentGridUser, Depends(user_dependency)],
81
+ origin: Annotated[Optional[HttpUrl], Query()] = None
82
+ ) -> T:
83
+ return self.get_client(user, origin, client_type)
84
+
85
+ return client_dependency
@@ -0,0 +1,87 @@
1
+ from typing import Annotated, Optional, Type, TypeVar, List
2
+ from fastapi import Query
3
+ from contentgrid_application_client.application import ContentGridApplicationClient
4
+ from contentgrid_hal_client.hal import HALFormsClient
5
+ from contentgrid_hal_client.security import ClientCredentialsApplicationAuthenticationManager
6
+ from pydantic import Field, HttpUrl
7
+ from pydantic_settings import SettingsConfigDict
8
+ from .client_factory import ContentGridBaseClientFactory, ContentGridClientFactorySettings
9
+
10
+ T = TypeVar('T', ContentGridApplicationClient, HALFormsClient)
11
+
12
+ class ContentGridServiceAccountFactorySettings(ContentGridClientFactorySettings):
13
+ model_config = SettingsConfigDict(env_prefix='CG_')
14
+
15
+ auth_url: str
16
+ client_name: str
17
+ client_secret: str
18
+ # Default endpoint for service account, can be overridden
19
+ default_endpoint: Optional[str] = Field(default=None, alias='CG_APP_URL')
20
+ # Allowed domains for service account access
21
+ allowed_domains: Optional[List[str]] = None
22
+
23
+ class ContentGridServiceAccountFactory(ContentGridBaseClientFactory):
24
+ def __init__(
25
+ self,
26
+ auth_url: Optional[str] = None,
27
+ client_name: Optional[str] = None,
28
+ client_secret: Optional[str] = None,
29
+ default_endpoint: Optional[str] = None,
30
+ allowed_domains: Optional[List[str]] = None,
31
+ ) -> None:
32
+ # Create config dict with provided parameters
33
+ config_dict = {}
34
+ if auth_url is not None:
35
+ config_dict['auth_url'] = auth_url
36
+ if client_name is not None:
37
+ config_dict['client_name'] = client_name
38
+ if client_secret is not None:
39
+ config_dict['client_secret'] = client_secret
40
+ if default_endpoint is not None:
41
+ config_dict['default_endpoint'] = default_endpoint
42
+ if allowed_domains is not None:
43
+ config_dict['allowed_domains'] = allowed_domains
44
+
45
+ # Create ContentGridServiceAccountFactorySettings instance which will use env vars for missing values
46
+ self.service_account_config = ContentGridServiceAccountFactorySettings(**config_dict)
47
+
48
+ self.authentication_manager = ClientCredentialsApplicationAuthenticationManager(
49
+ auth_uri=self.service_account_config.auth_url,
50
+ client_id=self.service_account_config.client_name,
51
+ client_secret=self.service_account_config.client_secret,
52
+ )
53
+ super().__init__()
54
+
55
+ @property
56
+ def config(self) -> ContentGridServiceAccountFactorySettings:
57
+ return self.service_account_config
58
+
59
+ def get_client(self, origin: Optional[HttpUrl] = None, client_type: Type[T] = ContentGridApplicationClient) -> T:
60
+ """Get a client of the specified type using service account authentication."""
61
+ client_endpoint = self._get_client_endpoint(
62
+ origin=origin,
63
+ allowed_domains=self.service_account_config.allowed_domains,
64
+ default_endpoint=self.service_account_config.default_endpoint
65
+ )
66
+
67
+ # Create a new authentication manager instance for this specific endpoint
68
+ auth_manager = ClientCredentialsApplicationAuthenticationManager(
69
+ auth_uri=self.service_account_config.auth_url,
70
+ client_id=self.service_account_config.client_name,
71
+ client_secret=self.service_account_config.client_secret,
72
+ resources=[client_endpoint]
73
+ )
74
+
75
+ return client_type(
76
+ client_endpoint=client_endpoint,
77
+ auth_manager=auth_manager,
78
+ )
79
+
80
+ def create_client_dependency(self, client_type: Type[T] = ContentGridApplicationClient):
81
+ """Create a dependency function for the specified client type (no user required)."""
82
+ def client_dependency(
83
+ origin: Annotated[Optional[HttpUrl], Query()] = None
84
+ ) -> T:
85
+ return self.get_client(origin, client_type)
86
+
87
+ return client_dependency
@@ -0,0 +1,14 @@
1
+ """Database factory classes for different database backends."""
2
+
3
+ from .base_factory import DatabaseConfig, DatabaseSessionFactory
4
+ from .sqlite import SQLiteConfig, SQLiteSessionFactory
5
+ from .postgres import PostgresConfig, PostgresSessionFactory
6
+
7
+ __all__ = [
8
+ "DatabaseConfig",
9
+ "DatabaseSessionFactory",
10
+ "SQLiteConfig",
11
+ "SQLiteSessionFactory",
12
+ "PostgresConfig",
13
+ "PostgresSessionFactory"
14
+ ]
@@ -0,0 +1,107 @@
1
+ import logging
2
+ from abc import ABC, abstractmethod
3
+ from contextlib import contextmanager
4
+ from typing import Optional, Dict, Any, Generator
5
+ from sqlmodel import SQLModel, Session, create_engine, text
6
+ from pydantic_settings import BaseSettings
7
+
8
+ class DatabaseConfig(BaseSettings):
9
+ debug: bool = False
10
+
11
+ class DatabaseSessionFactory(ABC):
12
+ """Abstract factory class to create database connections based on configuration."""
13
+
14
+ def __init__(self, config : Optional[DatabaseConfig] = None):
15
+ """Initialize with database configuration."""
16
+ self.db_config = config if config else DatabaseConfig()
17
+
18
+ # Get values from abstract methods and validate them
19
+ connection_string = self.create_connection_string()
20
+ connect_args = self.create_connect_args()
21
+ engine_kwargs = self.create_engine_kwargs()
22
+
23
+ # Assert that required values are provided
24
+ assert connection_string is not None and connection_string.strip(), "Connection string must be a non-empty string"
25
+ assert isinstance(connect_args, dict), "Connect args must be a dictionary"
26
+ assert isinstance(engine_kwargs, dict), "Engine kwargs must be a dictionary"
27
+
28
+ self.engine = create_engine(
29
+ connection_string,
30
+ connect_args=connect_args,
31
+ echo="debug" if self.db_config.debug else None, # Log SQL queries in debug mode
32
+ **engine_kwargs
33
+ )
34
+
35
+ @abstractmethod
36
+ def create_connection_string(self) -> str:
37
+ """Create the database connection string."""
38
+ pass
39
+
40
+ @abstractmethod
41
+ def create_connect_args(self) -> Dict[str, Any]:
42
+ """Create connection arguments."""
43
+ pass
44
+
45
+ @abstractmethod
46
+ def create_engine_kwargs(self) -> Dict[str, Any]:
47
+ """Create engine keyword arguments."""
48
+ pass
49
+
50
+ def create_db_and_tables(self) -> None:
51
+ """Create database tables from SQLModel metadata."""
52
+ SQLModel.metadata.create_all(self.engine)
53
+
54
+ # IMPORTANT : NO AUTOCOMMITS
55
+ # SQLAlchemy session management
56
+ # This function is used for dependency injection in FastAPI
57
+ # It provides a session that is NOT automatically committed.
58
+ # It is the responsibility of the caller to commit one or more transactions.
59
+ # When an error occurs, the session is rolled back.
60
+ def __call__(self) -> Generator[Session, None, None]:
61
+ """Get a database session for dependency injection."""
62
+ session = Session(self.engine)
63
+ try:
64
+ yield session
65
+ except Exception as e:
66
+ logging.exception(f"Database session error - Unexpected error: {e}")
67
+ session.rollback()
68
+ raise
69
+ finally:
70
+ session.close()
71
+
72
+ # Context managers cannot be used for dependency injection in FastAPI
73
+ # but they are useful for manual session management in scripts or tests.
74
+ @contextmanager
75
+ def get_db_session(self) -> Generator[Session, None, None]:
76
+ """Context manager for database sessions."""
77
+ session = Session(self.engine)
78
+ try:
79
+ yield session
80
+ session.commit()
81
+ except Exception as e:
82
+ logging.exception(f"Database transaction error - Unexpected error: {e}")
83
+ session.rollback()
84
+ raise
85
+ finally:
86
+ session.close()
87
+
88
+ def database_health_check(self) -> bool:
89
+ """Check if database connection is healthy."""
90
+ try:
91
+ with self.get_db_session() as session:
92
+ # Simple query to test connection
93
+ session.exec(text("SELECT 1"))
94
+ return True
95
+ except Exception as e:
96
+ logging.exception(f"Database health check failed - Unexpected error: {e}")
97
+ return False
98
+
99
+ def wipe_database(self) -> None:
100
+ """Wipes the database by dropping all tables and recreating them."""
101
+ from sqlmodel import SQLModel
102
+ try:
103
+ SQLModel.metadata.drop_all(self.engine)
104
+ SQLModel.metadata.create_all(self.engine)
105
+ logging.debug("Database tables dropped and recreated successfully")
106
+ except Exception as e:
107
+ logging.warning(f"Database cleanup failed: {e}")
@@ -0,0 +1,104 @@
1
+
2
+ from typing import Optional, Dict, Any
3
+ from .base_factory import DatabaseConfig, DatabaseSessionFactory
4
+
5
+ class PostgresConfig(DatabaseConfig):
6
+ pg_host: Optional[str] = None
7
+ pg_port: Optional[int] = None
8
+ pg_user: Optional[str] = None
9
+ pg_passwd: Optional[str] = None
10
+ pg_dbname: Optional[str] = None
11
+
12
+ # Database Connection Pool Configuration
13
+ db_pool_size: int = 10
14
+ db_max_overflow: int = 20
15
+ db_pool_recycle: int = 3600
16
+ db_pool_pre_ping: bool = True
17
+
18
+ # Pydantic settings is used throughout this file.
19
+ # The library allows for easy configuration management, including environment variable loading and validation.
20
+ # Each field can be configured to load from environment variables, and validation can be applied to ensure correct types and formats.
21
+ # Example:
22
+ # class MyConfig(BaseSettings):
23
+ # my_field: str = "default_value"
24
+ # my_required_bool: bool # This field must be provided
25
+ # my_optional_field: Optional[int] = None
26
+ #
27
+ # ENVIRONMENT VARIABLES:
28
+ # MY_FIELD=my_value
29
+ # MY_REQUIRED_BOOL=t
30
+ # MY_OPTIONAL_FIELD=42
31
+ # See https://docs.pydantic.dev/latest/api/pydantic_settings/ for more details.
32
+
33
+ class PostgresSessionFactory(DatabaseSessionFactory):
34
+ """Factory class to create PostgreSQL database connections."""
35
+ def __init__(self, pg_host: Optional[str] = None, pg_port: Optional[int] = None,
36
+ pg_user: Optional[str] = None, pg_passwd: Optional[str] = None,
37
+ pg_dbname: Optional[str] = None, debug: Optional[bool] = None,
38
+ db_pool_size: Optional[int] = None, db_max_overflow: Optional[int] = None,
39
+ db_pool_recycle: Optional[int] = None, db_pool_pre_ping: Optional[bool] = None):
40
+ """Initialize with PostgreSQL configuration."""
41
+ # Create config dict with provided parameters
42
+ config_dict = {}
43
+ if debug is not None:
44
+ config_dict['debug'] = debug
45
+ if db_pool_size is not None:
46
+ config_dict['db_pool_size'] = db_pool_size
47
+ if db_max_overflow is not None:
48
+ config_dict['db_max_overflow'] = db_max_overflow
49
+ if db_pool_recycle is not None:
50
+ config_dict['db_pool_recycle'] = db_pool_recycle
51
+ if db_pool_pre_ping is not None:
52
+ config_dict['db_pool_pre_ping'] = db_pool_pre_ping
53
+
54
+ # Override with explicit parameters if provided
55
+ if pg_host is not None:
56
+ config_dict['pg_host'] = pg_host
57
+ if pg_port is not None:
58
+ config_dict['pg_port'] = pg_port
59
+ if pg_user is not None:
60
+ config_dict['pg_user'] = pg_user
61
+ if pg_passwd is not None:
62
+ config_dict['pg_passwd'] = pg_passwd
63
+ if pg_dbname is not None:
64
+ config_dict['pg_dbname'] = pg_dbname
65
+
66
+ # Create PostgresConfig instance which will use env vars for missing values
67
+ self.postgres_config = PostgresConfig(**config_dict)
68
+
69
+ # Validate required fields
70
+ missing_fields = []
71
+ if not self.postgres_config.pg_host:
72
+ missing_fields.append("PG_HOST")
73
+ if not self.postgres_config.pg_port:
74
+ missing_fields.append("PG_PORT")
75
+ if not self.postgres_config.pg_user:
76
+ missing_fields.append("PG_USER")
77
+ if not self.postgres_config.pg_passwd:
78
+ missing_fields.append("PG_PASSWD")
79
+ if not self.postgres_config.pg_dbname:
80
+ missing_fields.append("PG_DBNAME")
81
+
82
+ if missing_fields:
83
+ raise ValueError(
84
+ f"Failed to configure postgres. Missing parameters or environment variables: {', '.join(missing_fields)}"
85
+ )
86
+
87
+ super().__init__(self.postgres_config)
88
+
89
+ def create_connection_string(self) -> str:
90
+ """Create the PostgreSQL connection string."""
91
+ return f"postgresql+psycopg2://{self.postgres_config.pg_user}:{self.postgres_config.pg_passwd}@{self.postgres_config.pg_host}:{self.postgres_config.pg_port}/{self.postgres_config.pg_dbname}"
92
+
93
+ def create_connect_args(self) -> Dict[str, Any]:
94
+ """Create PostgreSQL connection arguments."""
95
+ return {}
96
+
97
+ def create_engine_kwargs(self) -> Dict[str, Any]:
98
+ """Create PostgreSQL engine keyword arguments."""
99
+ return {
100
+ "pool_size": self.postgres_config.db_pool_size,
101
+ "max_overflow": self.postgres_config.db_max_overflow,
102
+ "pool_pre_ping": self.postgres_config.db_pool_pre_ping,
103
+ "pool_recycle": self.postgres_config.db_pool_recycle
104
+ }