maleo-database 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,12 @@
1
+ from pydantic import BaseModel, Field
2
+ from maleo.enums.environment import Environment
3
+ from maleo.types.base.dict import OptionalStringToStringDict
4
+ from maleo.types.base.string import OptionalString
5
+
6
+
7
class DatabaseIdentifierConfig(BaseModel):
    """Identity and metadata for a single managed database.

    Captures the on/off toggle, target environment, human-readable name,
    and optional free-form description/tags.
    """

    enabled: bool = Field(default=True, description="Whether the database is enabled")
    environment: Environment = Field(..., description="Database's environment")
    name: str = Field(..., description="Database's name")
    description: OptionalString = Field(
        default=None, description="Database's description"
    )
    tags: OptionalStringToStringDict = Field(
        default=None, description="Database's tags"
    )
@@ -0,0 +1,255 @@
1
+ from pydantic import BaseModel, ConfigDict, Field, model_validator
2
+ from typing import Self, TypeVar
3
+ from maleo.types.base.dict import OptionalStringToStringDict
4
+ from maleo.types.base.integer import ListOfIntegers
5
+ from maleo.types.base.string import OptionalString
6
+ from maleo.utils.formatters.case import to_camel
7
+ from ..enums import PoolingStrategy
8
+
9
+
10
class BasePoolingConfig(BaseModel):
    """Base configuration class for database connection pooling."""


# Type variable bound to BasePoolingConfig so generic code can accept any
# driver-specific pooling configuration while preserving its concrete type.
PoolingConfigT = TypeVar("PoolingConfigT", bound=BasePoolingConfig)
15
+
16
+
17
class PostgreSQLPoolingConfig(BasePoolingConfig):
    """PostgreSQL-specific pooling configuration.

    Field bounds (ge/le) are enforced by pydantic at construction time; the
    model validator additionally ties max_overflow to pool_size.
    """

    pool_size: int = Field(
        default=10, ge=1, le=1000, description="Number of connections in the pool"
    )
    max_overflow: int = Field(
        default=20, ge=0, le=500, description="Maximum number of overflow connections"
    )
    pool_timeout: float = Field(
        default=30.0,
        ge=1.0,
        le=300.0,
        description="Timeout in seconds for getting connection",
    )
    pool_recycle: int = Field(
        default=3600, ge=60, le=86400, description="Connection recycle time in seconds"
    )
    pool_pre_ping: bool = Field(
        default=True, description="Validate connections before use"
    )
    # Keep strategy and prepared_statement_cache_size as they're pooling-related
    strategy: PoolingStrategy = Field(
        default=PoolingStrategy.DYNAMIC, description="Pooling strategy"
    )
    prepared_statement_cache_size: int = Field(
        default=100, ge=0, le=10000, description="Prepared statement cache size"
    )
    pool_reset_on_return: bool = Field(
        default=True, description="Reset connection state on return to pool"
    )

    @model_validator(mode="after")
    def validate_overflow(self) -> Self:
        """Reject configs whose overflow capacity dwarfs the base pool."""
        if self.max_overflow > self.pool_size * 5:
            raise ValueError("max_overflow should not exceed 5x pool_size")
        return self
54
+
55
+
56
class MySQLPoolingConfig(BasePoolingConfig):
    """MySQL-specific pooling configuration.

    Unlike the PostgreSQL variant there is no overflow validator; only the
    per-field ge/le bounds constrain the values.
    """

    pool_size: int = Field(
        default=8, ge=1, le=500, description="Number of connections in the pool"
    )
    max_overflow: int = Field(
        default=15, ge=0, le=200, description="Maximum number of overflow connections"
    )
    pool_timeout: float = Field(
        default=20.0,
        ge=1.0,
        le=300.0,
        description="Timeout in seconds for getting connection",
    )
    pool_recycle: int = Field(
        default=7200, ge=60, le=86400, description="Connection recycle time in seconds"
    )
    pool_pre_ping: bool = Field(
        default=True, description="Validate connections before use"
    )
    strategy: PoolingStrategy = Field(
        default=PoolingStrategy.FIXED, description="Pooling strategy"
    )
    # Add autocommit to pooling since it affects connection behavior in the pool
    autocommit: bool = Field(default=False, description="Enable autocommit mode")
    # Move connect_timeout here since it's about pool connection establishment
    connect_timeout: float = Field(
        default=10.0, ge=1.0, le=60.0, description="Connection timeout in seconds"
    )
86
+
87
+
88
class SQLitePoolingConfig(BasePoolingConfig):
    """SQLite-specific pooling configuration.

    Pool bounds are intentionally tiny (pool_size le=10) since SQLite is
    file-based; WAL mode and busy timeout handle concurrency instead.
    """

    pool_size: int = Field(
        default=1, ge=1, le=10, description="Number of connections (limited for SQLite)"
    )
    max_overflow: int = Field(
        default=5, ge=0, le=20, description="Maximum overflow connections"
    )
    pool_timeout: float = Field(
        default=30.0, ge=1.0, le=300.0, description="Timeout in seconds"
    )
    # SQLite-specific pooling options
    wal_mode: bool = Field(
        default=True, description="Enable WAL mode for better concurrency"
    )
    busy_timeout: int = Field(
        default=30000, ge=1000, le=300000, description="Busy timeout in milliseconds"
    )
107
+
108
+
109
class SQLServerPoolingConfig(BasePoolingConfig):
    """SQL Server-specific pooling configuration.

    Mirrors the PostgreSQL layout but with a stricter overflow rule (3x
    instead of 5x) and SQL Server-specific timeouts/TLS options.
    """

    pool_size: int = Field(
        default=10, ge=1, le=500, description="Number of connections in the pool"
    )
    max_overflow: int = Field(
        default=20, ge=0, le=200, description="Maximum number of overflow connections"
    )
    pool_timeout: float = Field(
        default=30.0,
        ge=1.0,
        le=300.0,
        description="Timeout in seconds for getting connection",
    )
    pool_recycle: int = Field(
        default=3600, ge=60, le=86400, description="Connection recycle time in seconds"
    )
    pool_pre_ping: bool = Field(
        default=True, description="Validate connections before use"
    )
    strategy: PoolingStrategy = Field(
        default=PoolingStrategy.DYNAMIC, description="Pooling strategy"
    )
    # SQL Server-specific pooling settings
    connection_timeout: int = Field(
        default=30, ge=1, le=300, description="Connection timeout in seconds"
    )
    command_timeout: int = Field(
        default=30, ge=1, le=3600, description="Command timeout in seconds"
    )
    packet_size: int = Field(
        default=4096, ge=512, le=32767, description="Network packet size"
    )
    trust_server_certificate: bool = Field(
        default=False, description="Trust server certificate"
    )
    # Move encrypt here since it affects connection pool behavior
    encrypt: bool = Field(default=True, description="Encrypt connection")

    @model_validator(mode="after")
    def validate_overflow(self) -> Self:
        """Reject configs whose overflow capacity dwarfs the base pool."""
        if self.max_overflow > self.pool_size * 3:
            raise ValueError("max_overflow should not exceed 3x pool_size")
        return self
154
+
155
+
156
class MongoDBPoolingConfig(BasePoolingConfig):
    """MongoDB-specific pooling configuration."""

    # Fields are snake_case here; dumping with by_alias=True yields camelCase
    # keys via to_camel — presumably matching the MongoDB drivers' option
    # names (maxPoolSize, minPoolSize, ...). TODO confirm against the driver.
    model_config = ConfigDict(alias_generator=to_camel)

    max_pool_size: int = Field(
        default=100, ge=1, le=500, description="Maximum number of connections in pool"
    )
    min_pool_size: int = Field(
        default=0, ge=0, le=100, description="Minimum number of connections in pool"
    )
    max_idle_time_ms: int = Field(
        default=600000, ge=1000, le=3600000, description="Max idle time in milliseconds"
    )
    connect_timeout_ms: int = Field(
        default=20000,
        ge=1000,
        le=300000,
        description="Connection timeout in milliseconds",
    )
    server_selection_timeout_ms: int = Field(
        default=30000, ge=1000, le=300000, description="Server selection timeout"
    )
    max_connecting: int = Field(
        default=2,
        ge=1,
        le=10,
        description="Maximum number of concurrent connection attempts",
    )
185
+
186
+
187
class RedisPoolingConfig(BasePoolingConfig):
    """Redis-specific pooling configuration.

    Note: field names here are maleo's own; the client manager is
    responsible for mapping them onto redis-py's parameter names.
    """

    max_connections: int = Field(
        default=50, ge=1, le=1000, description="Maximum number of connections in pool"
    )
    retry_on_timeout: bool = Field(
        default=True, description="Retry on connection timeout"
    )
    health_check_interval: int = Field(
        default=30, ge=5, le=300, description="Health check interval in seconds"
    )
    connection_timeout: float = Field(
        default=5.0, ge=1.0, le=60.0, description="Connection timeout in seconds"
    )
    socket_timeout: float = Field(
        default=5.0, ge=1.0, le=60.0, description="Socket timeout in seconds"
    )
    socket_keepalive: bool = Field(default=True, description="Enable TCP keepalive")
    decode_responses: bool = Field(
        default=True, description="Decode responses to strings"
    )
209
+
210
+
211
class ElasticsearchPoolingConfig(BasePoolingConfig):
    """Elasticsearch-specific pooling configuration.

    Groups pool sizing, timeout/retry behavior, TLS/compression options and
    advanced pool settings; the validator keeps per-node connections within
    the overall pool size.
    """

    # Connection pool settings
    maxsize: int = Field(
        default=25, ge=1, le=100, description="Maximum number of connections in pool"
    )
    connections_per_node: int = Field(
        default=10, ge=1, le=50, description="Connections per Elasticsearch node"
    )

    # Timeout settings
    timeout: float = Field(
        default=10.0, ge=1.0, le=300.0, description="Request timeout in seconds"
    )
    max_retries: int = Field(
        default=3, ge=0, le=10, description="Maximum number of retries"
    )
    retry_on_timeout: bool = Field(default=False, description="Retry on timeout")
    retry_on_status: ListOfIntegers = Field(
        default_factory=lambda: [502, 503, 504],
        description="HTTP status codes to retry on",
    )

    # Connection behavior (move from connection config)
    http_compress: bool = Field(default=True, description="Enable HTTP compression")
    verify_certs: bool = Field(default=True, description="Verify SSL certificates")
    ca_certs: OptionalString = Field(
        default=None, description="Path to CA certificates"
    )

    # Advanced pool settings
    block: bool = Field(default=False, description="Block when pool is full")
    headers: OptionalStringToStringDict = Field(
        default=None, description="Default headers for requests"
    )
    dead_timeout: float = Field(
        default=60.0, ge=5.0, le=600.0, description="Dead node timeout in seconds"
    )

    @model_validator(mode="after")
    def validate_overflow(self) -> Self:
        """Ensure per-node connections never exceed the pool's total size."""
        if self.connections_per_node > self.maxsize:
            raise ValueError("connections_per_node must not exceed maxsize")
        return self
@@ -0,0 +1,85 @@
1
+ from enum import StrEnum
2
+
3
+
4
class Connection(StrEnum):
    """Client flavor: asynchronous (awaitable) vs synchronous."""

    ASYNC = "async"
    SYNC = "sync"
7
+
8
+
9
class Driver(StrEnum):
    """Database drivers; commented-out entries are not currently supported."""

    # SQL Databases - Most Popular
    POSTGRESQL = "postgresql"
    MYSQL = "mysql"
    SQLITE = "sqlite"

    # SQL Databases - Enterprise
    # ORACLE = "oracle"
    MSSQL = "mssql"  # SQL Server
    # MARIADB = "mariadb"

    # NoSQL Document Stores
    MONGODB = "mongodb"
    # COUCHDB = "couchdb"

    # NoSQL Key-Value
    REDIS = "redis"
    # DYNAMODB = "dynamodb"  # AWS

    # NoSQL Column Family
    # CASSANDRA = "cassandra"
    # HBASE = "hbase"

    # NoSQL Graph
    # NEO4J = "neo4j"
    # ARANGODB = "arangodb"

    # Time Series
    # INFLUXDB = "influxdb"
    # TIMESCALEDB = "timescaledb"  # PostgreSQL extension

    # In-Memory
    # MEMCACHED = "memcached"

    # Search Engines
    ELASTICSEARCH = "elasticsearch"
    # OPENSEARCH = "opensearch"

    # Cloud Native
    # FIRESTORE = "firestore"  # Google
    # COSMOSDB = "cosmosdb"  # Azure
50
+
51
+
52
class PostgreSQLSSLMode(StrEnum):
    """SSL mode options for PostgreSQL connections."""

    DISABLE = "disable"
    ALLOW = "allow"
    PREFER = "prefer"
    REQUIRE = "require"
    VERIFY_CA = "verify-ca"
    VERIFY_FULL = "verify-full"
59
+
60
+
61
class MySQLCharset(StrEnum):
    """Character sets selectable for MySQL connections."""

    UTF8 = "utf8"
    UTF8MB4 = "utf8mb4"
    LATIN1 = "latin1"
    ASCII = "ascii"
66
+
67
+
68
class MongoDBReadPreference(StrEnum):
    """MongoDB read-preference modes (camelCase values as used by the driver)."""

    PRIMARY = "primary"
    PRIMARY_PREFERRED = "primaryPreferred"
    SECONDARY = "secondary"
    SECONDARY_PREFERRED = "secondaryPreferred"
    NEAREST = "nearest"
74
+
75
+
76
class ElasticsearchScheme(StrEnum):
    """URL scheme used to reach Elasticsearch hosts."""

    HTTP = "http"
    HTTPS = "https"
79
+
80
+
81
class PoolingStrategy(StrEnum):
    """Connection-pooling strategies referenced by the pooling configs."""

    FIXED = "fixed"
    DYNAMIC = "dynamic"
    OVERFLOW = "overflow"
    QUEUE = "queue"
File without changes
File without changes
@@ -0,0 +1,66 @@
1
+ from elasticsearch import AsyncElasticsearch, Elasticsearch
2
+ from typing import Literal, Union, overload
3
+ from ...config import ElasticsearchDatabaseConfig
4
+ from ...enums import Connection
5
+
6
+
7
class ElasticsearchClientManager:
    """Owns one async and one sync Elasticsearch client for a single config.

    Both clients are built eagerly in __init__ from the same connection and
    pooling settings; dispose() closes both.
    """

    def __init__(self, config: ElasticsearchDatabaseConfig) -> None:
        self.config = config
        self._async_client: AsyncElasticsearch = self._init(Connection.ASYNC)
        self._sync_client: Elasticsearch = self._init(Connection.SYNC)

    @overload
    def _init(self, connection: Literal[Connection.ASYNC]) -> AsyncElasticsearch: ...
    @overload
    def _init(self, connection: Literal[Connection.SYNC]) -> Elasticsearch: ...
    def _init(
        self, connection: Connection = Connection.ASYNC
    ) -> Union[AsyncElasticsearch, Elasticsearch]:
        """Create (and cache on self) the client for the given mode."""
        hosts = [
            {"host": self.config.connection.host, "port": self.config.connection.port}
        ]

        # Build auth and pooling params properly
        client_kwargs = {}

        # Auth is attached only when both credentials are present.
        # NOTE(review): http_auth is deprecated in elasticsearch-py 8.x in
        # favor of basic_auth — confirm the client version pinned by this
        # package.
        if self.config.connection.username and self.config.connection.password:
            client_kwargs["http_auth"] = (
                self.config.connection.username,
                self.config.connection.password,
            )

        # Use pooling config directly; the excluded fields are maleo-level
        # pool options with no corresponding client constructor argument.
        pooling_kwargs = self.config.pooling.model_dump(
            exclude={
                "connections_per_node",
                "block",
                "headers",
                "dead_timeout",
            },  # ES-specific excludes
            exclude_none=True,
        )
        client_kwargs.update(pooling_kwargs)

        if connection is Connection.ASYNC:
            self._async_client = AsyncElasticsearch(hosts, **client_kwargs)
            return self._async_client
        else:
            self._sync_client = Elasticsearch(hosts, **client_kwargs)
            return self._sync_client

    @overload
    def get(self, connection: Literal[Connection.ASYNC]) -> AsyncElasticsearch: ...
    @overload
    def get(self, connection: Literal[Connection.SYNC]) -> Elasticsearch: ...
    def get(
        self, connection: Connection = Connection.ASYNC
    ) -> Union[AsyncElasticsearch, Elasticsearch]:
        """Return the cached client, lazily rebuilding if it is falsy/unset."""
        if connection is Connection.ASYNC:
            return self._async_client or self._init(Connection.ASYNC)
        else:
            return self._sync_client or self._init(Connection.SYNC)

    async def dispose(self):
        """Close both clients (async close is awaited, sync close is direct)."""
        await self._async_client.close()
        self._sync_client.close()
@@ -0,0 +1,53 @@
1
+ from motor.motor_asyncio import AsyncIOMotorClient
2
+ from pymongo import MongoClient
3
+ from typing import Literal, Union, overload
4
+ from ...config import MongoDBDatabaseConfig
5
+ from ...enums import Connection
6
+
7
+
8
class MongoDBClientManager:
    """Owns one async (Motor) and one sync (PyMongo) client for a database.

    Both clients are created eagerly from the same config; dispose() closes
    them both.
    """

    def __init__(self, config: MongoDBDatabaseConfig) -> None:
        self.config = config
        self._async_client: AsyncIOMotorClient = self._init(Connection.ASYNC)
        self._sync_client: MongoClient = self._init(Connection.SYNC)

    @overload
    def _init(self, connection: Literal[Connection.ASYNC]) -> AsyncIOMotorClient: ...
    @overload
    def _init(self, connection: Literal[Connection.SYNC]) -> MongoClient: ...
    def _init(
        self, connection: Connection = Connection.ASYNC
    ) -> Union[AsyncIOMotorClient, MongoClient]:
        """Build (and cache) a client of the requested flavor."""
        # Aliased dump yields the option names the drivers expect.
        options = self.config.pooling.model_dump(by_alias=True, exclude_none=True)
        target_url = self.config.connection.make_url(connection)

        if connection is Connection.ASYNC:
            client = AsyncIOMotorClient(target_url, **options)
            self._async_client = client
        else:
            client = MongoClient(target_url, **options)
            self._sync_client = client
        return client

    @overload
    def get(self, connection: Literal[Connection.ASYNC]) -> AsyncIOMotorClient: ...
    @overload
    def get(self, connection: Literal[Connection.SYNC]) -> MongoClient: ...
    def get(
        self, connection: Connection = Connection.ASYNC
    ) -> Union[AsyncIOMotorClient, MongoClient]:
        """Return the already-built client for the requested flavor."""
        return (
            self._async_client
            if connection is Connection.ASYNC
            else self._sync_client
        )

    def get_database(self, connection: Connection = Connection.ASYNC):
        """Get the specific database object."""
        return self.get(connection)[self.config.connection.database]

    async def dispose(self):
        """Close both clients; both close() calls are synchronous."""
        self._async_client.close()
        self._sync_client.close()
@@ -0,0 +1,57 @@
1
+ from redis.asyncio import Redis as AsyncRedis
2
+ from redis import Redis as SyncRedis
3
+ from typing import Literal, Union, overload
4
+ from ...config import RedisDatabaseConfig
5
+ from ...enums import Connection
6
+
7
+
8
class RedisClientManager:
    """Owns one async and one sync Redis client built from a single config.

    Clients are created eagerly in __init__; dispose() closes both.
    """

    def __init__(self, config: RedisDatabaseConfig) -> None:
        self.config = config
        self._async_client: AsyncRedis = self._init(Connection.ASYNC)
        self._sync_client: SyncRedis = self._init(Connection.SYNC)

    @overload
    def _init(self, connection: Literal[Connection.ASYNC]) -> AsyncRedis: ...
    @overload
    def _init(self, connection: Literal[Connection.SYNC]) -> SyncRedis: ...
    def _init(
        self, connection: Connection = Connection.ASYNC
    ) -> Union[AsyncRedis, SyncRedis]:
        """Create (and cache on self) a client for the given connection mode."""
        url = self.config.connection.make_url(connection)

        # Redis clients expect different parameter names
        pooling_config = self.config.pooling.model_dump(exclude_none=True)
        redis_kwargs = {
            "max_connections": pooling_config.get("max_connections"),
            "retry_on_timeout": pooling_config.get("retry_on_timeout"),
            # redis-py has no "connection_timeout" parameter; its connect
            # timeout is named socket_connect_timeout. Forwarding the
            # unmapped name through from_url() raises TypeError inside
            # Connection.__init__, so map it here.
            "socket_connect_timeout": pooling_config.get("connection_timeout"),
            "socket_timeout": pooling_config.get("socket_timeout"),
            "socket_keepalive": pooling_config.get("socket_keepalive"),
            "decode_responses": pooling_config.get("decode_responses"),
            # health_check_interval doesn't apply to from_url method
        }
        redis_kwargs = {k: v for k, v in redis_kwargs.items() if v is not None}

        if connection is Connection.ASYNC:
            self._async_client = AsyncRedis.from_url(url, **redis_kwargs)
            return self._async_client
        else:
            self._sync_client = SyncRedis.from_url(url, **redis_kwargs)
            return self._sync_client

    @overload
    def get(self, connection: Literal[Connection.ASYNC]) -> AsyncRedis: ...
    @overload
    def get(self, connection: Literal[Connection.SYNC]) -> SyncRedis: ...
    def get(
        self, connection: Connection = Connection.ASYNC
    ) -> Union[AsyncRedis, SyncRedis]:
        """Return the cached client for the given connection mode."""
        if connection is Connection.ASYNC:
            return self._async_client
        elif connection is Connection.SYNC:
            return self._sync_client

    async def dispose(self):
        """Close both clients and release their connection pools."""
        await self._async_client.close()
        self._sync_client.close()
File without changes
@@ -0,0 +1,56 @@
1
+ from sqlalchemy.engine import create_engine, Engine
2
+ from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
3
+ from typing import Literal, Tuple, Union, overload
4
+ from ...config import MySQLDatabaseConfig
5
+ from ...enums import Connection
6
+
7
+
8
class MySQLEngineManager:
    """Manages paired async and sync SQLAlchemy engines for a MySQL database.

    Engines are created eagerly at construction and reused for the manager's
    lifetime; call dispose() to release both connection pools.
    """

    def __init__(self, config: MySQLDatabaseConfig) -> None:
        self.config = config
        self._async_engine: AsyncEngine = self._init(Connection.ASYNC)
        self._sync_engine: Engine = self._init(Connection.SYNC)

    @overload
    def _init(self, connection: Literal[Connection.ASYNC]) -> AsyncEngine: ...
    @overload
    def _init(self, connection: Literal[Connection.SYNC]) -> Engine: ...
    def _init(
        self, connection: Connection = Connection.ASYNC
    ) -> Union[AsyncEngine, Engine]:
        """Create (and cache) the engine for the requested connection mode."""
        url = self.config.connection.make_url(connection)

        # `strategy` is a maleo-level concept with no SQLAlchemy equivalent.
        # `autocommit` and `connect_timeout` are DBAPI connect() options, not
        # create_engine() keyword arguments — create_engine() raises on
        # unconsumed kwargs, so route them through connect_args instead.
        pooling_kwargs = self.config.pooling.model_dump(
            exclude={"strategy", "autocommit", "connect_timeout"},
            exclude_none=True,
        )
        connect_args = self.config.pooling.model_dump(
            include={"autocommit", "connect_timeout"},
            exclude_none=True,
        )

        engine_kwargs = {
            "echo": self.config.connection.echo,
            "connect_args": connect_args,
            **pooling_kwargs,
        }

        if connection is Connection.ASYNC:
            self._async_engine = create_async_engine(url, **engine_kwargs)
            return self._async_engine
        elif connection is Connection.SYNC:
            self._sync_engine = create_engine(url, **engine_kwargs)
            return self._sync_engine

    @overload
    def get(self, connection: Literal[Connection.ASYNC]) -> AsyncEngine: ...
    @overload
    def get(self, connection: Literal[Connection.SYNC]) -> Engine: ...
    def get(
        self, connection: Connection = Connection.ASYNC
    ) -> Union[AsyncEngine, Engine]:
        """Return the cached engine for the given connection mode."""
        if connection is Connection.ASYNC:
            return self._async_engine
        elif connection is Connection.SYNC:
            return self._sync_engine

    def get_all(self) -> Tuple[AsyncEngine, Engine]:
        """Return both engines as an (async, sync) pair."""
        return (self._async_engine, self._sync_engine)

    async def dispose(self):
        """Dispose both engines' connection pools."""
        await self._async_engine.dispose()
        self._sync_engine.dispose()
@@ -0,0 +1,58 @@
1
+ from sqlalchemy.engine import create_engine, Engine
2
+ from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
3
+ from typing import Literal, Tuple, Union, overload
4
+ from ...config import PostgreSQLDatabaseConfig
5
+ from ...enums import Connection
6
+
7
+
8
class PostgreSQLEngineManager:
    """Keeps one async and one sync SQLAlchemy engine for a PostgreSQL config.

    Both engines are built eagerly in __init__; dispose() releases their
    connection pools.
    """

    def __init__(self, config: PostgreSQLDatabaseConfig) -> None:
        self.config = config
        self._async_engine: AsyncEngine = self._init(Connection.ASYNC)
        self._sync_engine: Engine = self._init(Connection.SYNC)

    @overload
    def _init(self, connection: Literal[Connection.ASYNC]) -> AsyncEngine: ...
    @overload
    def _init(self, connection: Literal[Connection.SYNC]) -> Engine: ...
    def _init(
        self, connection: Connection = Connection.ASYNC
    ) -> Union[AsyncEngine, Engine]:
        """Build and cache the engine for the requested connection mode."""
        # These fields are maleo-level pooling settings with no direct
        # create_engine() counterpart, so they are stripped from the dump.
        non_engine_fields = {
            "strategy",
            "prepared_statement_cache_size",
            "pool_reset_on_return",
        }
        engine_kwargs = {
            "echo": self.config.connection.echo,
            **self.config.pooling.model_dump(
                exclude=non_engine_fields, exclude_none=True
            ),
        }
        url = self.config.connection.make_url(connection)

        if connection is Connection.ASYNC:
            self._async_engine = create_async_engine(url, **engine_kwargs)
            return self._async_engine
        elif connection is Connection.SYNC:
            self._sync_engine = create_engine(url, **engine_kwargs)
            return self._sync_engine

    @overload
    def get(self, connection: Literal[Connection.ASYNC]) -> AsyncEngine: ...
    @overload
    def get(self, connection: Literal[Connection.SYNC]) -> Engine: ...
    def get(
        self, connection: Connection = Connection.ASYNC
    ) -> Union[AsyncEngine, Engine]:
        """Return the engine previously built for this connection mode."""
        if connection is Connection.ASYNC:
            return self._async_engine
        elif connection is Connection.SYNC:
            return self._sync_engine

    def get_all(self) -> Tuple[AsyncEngine, Engine]:
        """Return both engines as an (async, sync) pair."""
        return (self._async_engine, self._sync_engine)

    async def dispose(self):
        """Release both engines' pooled connections."""
        await self._async_engine.dispose()
        self._sync_engine.dispose()